OpenDAS / dgl
Commit d3560b71 (Unverified)
Authored Apr 02, 2020 by Mufei Li
Committed by GitHub Apr 02, 2020
[DGL-LifeSci] Documentation (#1414)
* Update * Update * Update
parent 7e0893e6
Changes
30
Showing 10 changed files with 708 additions and 58 deletions (+708 -58)
apps/life_sci/python/dgllife/data/tox21.py +11 -3
apps/life_sci/python/dgllife/data/uspto.py +4 -5
apps/life_sci/python/dgllife/model/pretrain.py +14 -9
apps/life_sci/python/dgllife/utils/early_stop.py +34 -1
apps/life_sci/python/dgllife/utils/eval.py +41 -15
apps/life_sci/python/dgllife/utils/featurizers.py +222 -7
apps/life_sci/python/dgllife/utils/mol_to_graph.py +329 -0
apps/life_sci/python/dgllife/utils/rdkit_utils.py +30 -6
apps/life_sci/python/dgllife/utils/splitters.py +22 -11
docs/source/install/index.rst +1 -1
apps/life_sci/python/dgllife/data/tox21.py
...
@@ -21,7 +21,6 @@ class Tox21(MoleculeCSVDataset):
...
@@ -21,7 +21,6 @@ class Tox21(MoleculeCSVDataset):
A common issue for multi-task prediction is that some datapoints are not labeled for
A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
See examples below for more details.
All molecules are converted into DGLGraphs. After the first-time construction,
All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them everytime.
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them everytime.
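The masking idea described in this docstring can be illustrated with a minimal pure-Python sketch. The function name `masked_bce_with_logits` and the toy data are hypothetical and are not part of dgllife; the sketch only shows how a binary mask keeps the placeholder 0 labels out of the loss.

```python
import math

def masked_bce_with_logits(logits, labels, mask):
    """Mean binary cross-entropy over the entries whose mask is 1."""
    total, count = 0.0, 0
    for logit_row, label_row, mask_row in zip(logits, labels, mask):
        for z, y, m in zip(logit_row, label_row, mask_row):
            if m == 0:
                continue  # missing label: the placeholder 0 must not contribute
            # Numerically stable binary cross-entropy on a raw logit z
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count if count else 0.0

# Two molecules, two tasks; molecule 0 has no label for task 1.
logits = [[0.0, 2.0], [-1.0, 0.5]]
labels = [[1, 0], [0, 0]]   # the 0 at [0][1] is only a placeholder
mask = [[1, 0], [1, 1]]
loss = masked_bce_with_logits(logits, labels, mask)
```

Changing the placeholder at `labels[0][1]` leaves `loss` unchanged, which is the property the mask is meant to guarantee.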
...
@@ -87,9 +86,18 @@ class Tox21(MoleculeCSVDataset):
...
@@ -87,9 +86,18 @@ class Tox21(MoleculeCSVDataset):
def task_pos_weights(self):
def task_pos_weights(self):
"""Get weights for positive samples on each task
"""Get weights for positive samples on each task
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
Returns
Returns
-------
-------
numpy.ndarray
Tensor of dtype float32 and shape (T)
numpy array gives the weight of positive samples on all tasks
Weight of positive samples on all tasks
"""
"""
return self._task_pos_weights
return self._task_pos_weights
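The weighting rule in the new docstring (negative weight 1, positive weight equal to negatives divided by positives) can be sketched as follows. This is an illustration in plain Python, not the dataset's implementation: the helper name `pos_weights_per_task`, the choice to count only labeled entries via a mask, and the fallback of 1.0 when a task has no positives are all assumptions.

```python
def pos_weights_per_task(labels, mask):
    """Return one positive-class weight per task: #negatives / #positives."""
    num_tasks = len(labels[0])
    weights = []
    for t in range(num_tasks):
        # Assumption: unlabeled entries (mask == 0) are excluded from both counts
        num_pos = sum(1 for y, m in zip(labels, mask) if m[t] == 1 and y[t] == 1)
        num_neg = sum(1 for y, m in zip(labels, mask) if m[t] == 1 and y[t] == 0)
        # Assumption: fall back to 1.0 when a task has no positive sample
        weights.append(num_neg / num_pos if num_pos > 0 else 1.0)
    return weights

# Four molecules, two tasks; molecule 2 is unlabeled for task 1.
labels = [[1, 0], [0, 1], [0, 1], [0, 0]]
mask = [[1, 1], [1, 1], [1, 0], [1, 1]]
weights = pos_weights_per_task(labels, mask)
```

Here task 0 has one positive and three negatives, and task 1 has one labeled positive and two labeled negatives, so `weights` is `[3.0, 2.0]`.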
apps/life_sci/python/dgllife/data/uspto.py
...
@@ -363,15 +363,15 @@ class WLNReactionDataset(object):
...
@@ -363,15 +363,15 @@ class WLNReactionDataset(object):
Returns
Returns
-------
-------
str
str
Reaction
Reaction.
str
str
Graph edits for the reaction
Graph edits for the reaction
rdkit.Chem.rdchem.Mol
rdkit.Chem.rdchem.Mol
RDKit molecule instance
RDKit molecule instance for reactants
DGLGraph
DGLGraph
DGLGraph for the ith molecular graph
DGLGraph for the ith molecular graph of reactants
DGLGraph
DGLGraph
Complete DGLGraph, which will be needed for predicting
Complete DGLGraph for reactants, which will be needed for predicting
scores between each pair of atoms
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
Features for each pair of atoms.
...
@@ -477,7 +477,6 @@ class USPTO(WLNReactionDataset):
...
@@ -477,7 +477,6 @@ class USPTO(WLNReactionDataset):
-------
-------
str
str
* 'full' for the complete dataset
* 'train' for the training set
* 'train' for the training set
* 'val' for the validation set
* 'val' for the validation set
* 'test' for the test set
* 'test' for the test set
...
...
apps/life_sci/python/dgllife/model/pretrain.py
...
@@ -68,15 +68,20 @@ def load_pretrained(model_name, log=True):
...
@@ -68,15 +68,20 @@ def load_pretrained(model_name, log=True):
model_name : str
model_name : str
Currently supported options include
Currently supported options include
* ``'GCN_Tox21'``
* ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
* ``'GAT_Tox21'``
* ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
* ``'AttentiveFP_Aromaticity'``
* ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of aromatic atoms on a subset of Pubmed
* ``'DGMG_ChEMBL_canonical'``
* ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical atom order
* ``'DGMG_ChEMBL_random'``
* ``'DGMG_ChEMBL_random'``: A DGMG model trained on ChEMBL for molecule generation with a random atom order
* ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_canonical'``: A DGMG model trained on ZINC for molecule generation with a canonical atom order
* ``'DGMG_ZINC_random'``
* ``'DGMG_ZINC_random'``: A DGMG model pre-trained on ZINC for molecule generation with a random atom order
* ``'JTNN_ZINC'``
* ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
* ``'wln_center_uspto'``
* ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
log : bool
log : bool
Whether to print progress for model loading
Whether to print progress for model loading
...
...
apps/life_sci/python/dgllife/utils/early_stop.py
...
@@ -22,7 +22,40 @@ class EarlyStopping(object):
...
@@ -22,7 +22,40 @@ class EarlyStopping(object):
The early stopping will happen if we do not observe performance
The early stopping will happen if we do not observe performance
improvement for ``patience`` consecutive epochs.
improvement for ``patience`` consecutive epochs.
filename : str or None
filename : str or None
Filename for storing the model checkpoint
Filename for storing the model checkpoint. If not specified,
we will automatically generate a file starting with ``early_stop``
based on the current time.
Examples
--------
Below gives a demo for a fake training process.
>>> import torch
>>> import torch.nn as nn
>>> from torch.nn import MSELoss
>>> from torch.optim import Adam
>>> from dgllife.utils import EarlyStopping
>>> model = nn.Linear(1, 1)
>>> criterion = MSELoss()
>>> # For MSE, the lower, the better
>>> stopper = EarlyStopping(mode='lower', filename='test.pth')
>>> optimizer = Adam(params=model.parameters(), lr=1e-3)
>>> for epoch in range(1000):
>>>     x = torch.randn(1, 1) # Fake input
>>>     y = torch.randn(1, 1) # Fake label
>>>     pred = model(x)
>>>     loss = criterion(y, pred)
>>>     optimizer.zero_grad()
>>>     loss.backward()
>>>     optimizer.step()
>>>     early_stop = stopper.step(loss.detach().data, model)
>>>     if early_stop:
>>>         break
>>> # Load the final parameters saved by the model
>>> stopper.load_checkpoint(model)
>>> # Load the final parameters saved by the model
>>> stopper.load_checkpoint(model)
"""
"""
def __init__(self, mode='higher', patience=10, filename=None):
def __init__(self, mode='higher', patience=10, filename=None):
if filename is None:
if filename is None:
...
...
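The patience rule described in the EarlyStopping docstring (stop once no improvement is seen for ``patience`` consecutive epochs) can be sketched without any checkpointing. The class `SimpleEarlyStopping` below is a hypothetical stand-in, not the dgllife class; it omits saving and loading model parameters and only tracks the best score.

```python
class SimpleEarlyStopping:
    """Track the best score and signal a stop after `patience` bad epochs."""

    def __init__(self, mode='higher', patience=10):
        assert mode in ('higher', 'lower')
        self.mode = mode
        self.patience = patience
        self.counter = 0        # consecutive epochs without improvement
        self.best_score = None

    def _is_better(self, score):
        if self.mode == 'higher':
            return score > self.best_score
        return score < self.best_score

    def step(self, score):
        """Return True if training should stop after this epoch."""
        if self.best_score is None or self._is_better(score):
            self.best_score = score
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

# Fake validation losses; lower is better.
stopper = SimpleEarlyStopping(mode='lower', patience=2)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

With these losses the loop stops at epoch 3, because epochs 2 and 3 both fail to beat the best loss of 0.8; the later 0.7 is never seen.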
apps/life_sci/python/dgllife/utils/eval.py
...
@@ -19,18 +19,44 @@ class Meter(object):
...
@@ -19,18 +19,44 @@ class Meter(object):
Currently we support evaluation with 4 metrics:
Currently we support evaluation with 4 metrics:
* pearson r2
* ``pearson r2``
* mae
* ``mae``
* rmse
* ``rmse``
* roc auc score
* ``roc auc score``
Parameters
Parameters
----------
----------
mean : torch.float32 tensor of shape (T) or None.
mean : torch.float32 tensor of shape (T) or None.
Mean of existing training labels across tasks if not None. T for the number of tasks. Default to None.
Mean of existing training labels across tasks if not ``None``. ``T`` for the number of tasks. Default to ``None`` and we assume no label normalization has been performed.
std : torch.float32 tensor of shape (T)
std : torch.float32 tensor of shape (T)
Std of existing training labels across tasks if not None.
Std of existing training labels across tasks if not ``None``. Default to ``None``
and we assume no label normalization has been performed.
Examples
--------
Below gives a demo for a fake evaluation epoch.
>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Simulate 10 fake mini-batches
>>> for batch_id in range(10):
>>>     batch_label = torch.randn(3, 3)
>>>     batch_pred = torch.randn(3, 3)
>>>     meter.update(batch_pred, batch_label)
>>> # Get MAE for all tasks
>>> print(meter.compute_metric('mae'))
[1.1325558423995972, 1.0543707609176636, 1.094650149345398]
>>> # Get MAE averaged over all tasks
>>> print(meter.compute_metric('mae', reduction='mean'))
1.0938589175542195
>>> # Get the sum of MAE over all tasks
>>> print(meter.compute_metric('mae', reduction='sum'))
3.2815767526626587
"""
"""
def __init__(self, mean=None, std=None):
def __init__(self, mean=None, std=None):
self.mask = []
self.mask = []
...
@@ -50,13 +76,13 @@ class Meter(object):
...
@@ -50,13 +76,13 @@ class Meter(object):
Parameters
Parameters
----------
----------
y_pred : float32 tensor
y_pred : float32 tensor
Predicted labels with shape (B, T),
Predicted labels with shape ``(B, T)``,
B for number of graphs in the batch and T for the number of tasks
``B`` for number of graphs in the batch and ``T`` for the number of tasks
y_true : float32 tensor
y_true : float32 tensor
Ground truth labels with shape (B, T)
Ground truth labels with shape ``(B, T)``
mask : None or float32 tensor
mask : None or float32 tensor
Binary mask indicating the existence of ground truth labels with
Binary mask indicating the existence of ground truth labels with
shape (B, T). If None, we assume that all labels exist and create
shape ``(B, T)``. If None, we assume that all labels exist and create
a one-tensor for placeholder.
a one-tensor for placeholder.
"""
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_pred.append(y_pred.detach().cpu())
...
@@ -237,10 +263,10 @@ class Meter(object):
...
@@ -237,10 +263,10 @@ class Meter(object):
----------
----------
metric_name : str
metric_name : str
* 'r2': compute squared Pearson correlation coefficient
* ``'r2'``: compute squared Pearson correlation coefficient
* 'mae': compute mean absolute error
* ``'mae'``: compute mean absolute error
* 'rmse': compute root mean square error
* ``'rmse'``: compute root mean square error
* 'roc_auc_score': compute roc-auc score
* ``'roc_auc_score'``: compute roc-auc score
reduction : 'none' or 'mean' or 'sum'
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Controls the form of scores for all tasks
...
...
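To make the per-task scores and the ``reduction`` options concrete, here is a minimal pure-Python sketch of a masked MAE. The function `mae_per_task` is hypothetical and is not how Meter is implemented; it assumes every task has at least one labeled entry and uses nested lists in place of tensors of shape (B, T).

```python
def mae_per_task(y_pred, y_true, mask=None, reduction='none'):
    """Mean absolute error per task, ignoring entries whose mask is 0."""
    num_tasks = len(y_true[0])
    if mask is None:
        # Same convention as the docstring: assume all labels exist
        mask = [[1] * num_tasks for _ in y_true]
    scores = []
    for t in range(num_tasks):
        errors = [abs(p[t] - y[t])
                  for p, y, m in zip(y_pred, y_true, mask) if m[t]]
        scores.append(sum(errors) / len(errors))
    if reduction == 'mean':
        return sum(scores) / len(scores)
    if reduction == 'sum':
        return sum(scores)
    return scores  # reduction == 'none': one score per task

# Two graphs, two tasks; graph 1 has no label for task 1.
y_pred = [[1.0, 2.0], [3.0, 4.0]]
y_true = [[0.0, 2.5], [1.0, 9.0]]
mask = [[1, 1], [1, 0]]
per_task = mae_per_task(y_pred, y_true, mask)
mean_mae = mae_per_task(y_pred, y_true, mask, reduction='mean')
sum_mae = mae_per_task(y_pred, y_true, mask, reduction='sum')
```

The masked entry (prediction 4.0 against placeholder 9.0) does not affect task 1, so `per_task` is `[1.5, 0.5]`, `mean_mae` is 1.0 and `sum_mae` is 2.0.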
apps/life_sci/python/dgllife/utils/featurizers.py
...
@@ -64,6 +64,16 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False):
...
@@ -64,6 +64,16 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False):
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
and ``len(allowable_set) + 1`` otherwise.
and ``len(allowable_set) + 1`` otherwise.
Examples
--------
>>> from dgllife.utils import one_hot_encoding
>>> one_hot_encoding('C', ['C', 'O'])
[True, False]
>>> one_hot_encoding('S', ['C', 'O'])
[False, False]
>>> one_hot_encoding('S', ['C', 'O'], encode_unknown=True)
[False, False, True]
"""
"""
if encode_unknown and (allowable_set[-1] is not None):
if encode_unknown and (allowable_set[-1] is not None):
allowable_set.append(None)
allowable_set.append(None)
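The behaviour documented in the new Examples section can be reproduced with a short stand-alone sketch. The function `one_hot_sketch` is a hypothetical re-implementation for illustration only; unlike the two lines of the real function shown above, it copies the allowable set instead of appending ``None`` to the caller's list.

```python
def one_hot_sketch(x, allowable_set, encode_unknown=False):
    """One-hot encode x against allowable_set, with an optional unknown slot."""
    choices = list(allowable_set)  # copy so the caller's list is not modified
    if encode_unknown and (not choices or choices[-1] is not None):
        choices.append(None)       # extra last position for unknown inputs
    if encode_unknown and x not in choices:
        x = None                   # route unknown inputs to the last position
    return [x == s for s in choices]

known = one_hot_sketch('C', ['C', 'O'])
unknown_dropped = one_hot_sketch('S', ['C', 'O'])
unknown_kept = one_hot_sketch('S', ['C', 'O'], encode_unknown=True)
```

The three results match the docstring examples: `[True, False]`, `[False, False]` and `[False, False, True]`.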
...
@@ -98,6 +108,12 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -98,6 +108,12 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atomic_number_one_hot
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
...
@@ -123,6 +139,12 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -123,6 +139,12 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atom_type_one_hot
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(1, 101))
allowable_set = list(range(1, 101))
...
@@ -140,6 +162,11 @@ def atomic_number(atom):
...
@@ -140,6 +162,11 @@ def atomic_number(atom):
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atomic_number_one_hot
atom_type_one_hot
"""
"""
return [atom.GetAtomicNum()]
return [atom.GetAtomicNum()]
...
@@ -166,6 +193,9 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -166,6 +193,9 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also
See Also
--------
--------
one_hot_encoding
atom_degree
atom_total_degree
atom_total_degree_one_hot
atom_total_degree_one_hot
"""
"""
if allowable_set is None:
if allowable_set is None:
...
@@ -190,7 +220,9 @@ def atom_degree(atom):
...
@@ -190,7 +220,9 @@ def atom_degree(atom):
See Also
See Also
--------
--------
atom_degree_one_hot
atom_total_degree
atom_total_degree
atom_total_degree_one_hot
"""
"""
return [atom.GetDegree()]
return [atom.GetDegree()]
...
@@ -209,7 +241,10 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -209,7 +241,10 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also
See Also
--------
--------
one_hot_encoding
atom_degree
atom_degree_one_hot
atom_degree_one_hot
atom_total_degree
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(6))
allowable_set = list(range(6))
...
@@ -218,14 +253,16 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -218,14 +253,16 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
def atom_total_degree(atom):
def atom_total_degree(atom):
"""The degree of an atom including Hs.
"""The degree of an atom including Hs.
See Also
--------
atom_degree
Returns
Returns
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_total_degree_one_hot
atom_degree
atom_degree_one_hot
"""
"""
return [atom.GetTotalDegree()]
return [atom.GetTotalDegree()]
...
@@ -246,6 +283,11 @@ def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
...
@@ -246,6 +283,11 @@ def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_explicit_valence
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(1, 7))
allowable_set = list(range(1, 7))
...
@@ -263,6 +305,10 @@ def atom_explicit_valence(atom):
...
@@ -263,6 +305,10 @@ def atom_explicit_valence(atom):
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_explicit_valence_one_hot
"""
"""
return [atom.GetExplicitValence()]
return [atom.GetExplicitValence()]
...
@@ -283,6 +329,10 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
...
@@ -283,6 +329,10 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
atom_implicit_valence
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(7))
allowable_set = list(range(7))
...
@@ -300,6 +350,10 @@ def atom_implicit_valence(atom):
...
@@ -300,6 +350,10 @@ def atom_implicit_valence(atom):
------
------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_implicit_valence_one_hot
"""
"""
return [atom.GetImplicitValence()]
return [atom.GetImplicitValence()]
...
@@ -323,6 +377,10 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -323,6 +377,10 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [Chem.rdchem.HybridizationType.SP,
allowable_set = [Chem.rdchem.HybridizationType.SP,
...
@@ -349,6 +407,11 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -349,6 +407,11 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_total_num_H
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(5))
allowable_set = list(range(5))
...
@@ -366,6 +429,10 @@ def atom_total_num_H(atom):
...
@@ -366,6 +429,10 @@ def atom_total_num_H(atom):
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_total_num_H_one_hot
"""
"""
return [atom.GetTotalNumHs()]
return [atom.GetTotalNumHs()]
...
@@ -386,6 +453,11 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -386,6 +453,11 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_formal_charge
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(-2, 3))
allowable_set = list(range(-2, 3))
...
@@ -403,6 +475,10 @@ def atom_formal_charge(atom):
...
@@ -403,6 +475,10 @@ def atom_formal_charge(atom):
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_formal_charge_one_hot
"""
"""
return [atom.GetFormalCharge()]
return [atom.GetFormalCharge()]
...
@@ -423,6 +499,11 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=
...
@@ -423,6 +499,11 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_num_radical_electrons
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = list(range(5))
allowable_set = list(range(5))
...
@@ -440,6 +521,10 @@ def atom_num_radical_electrons(atom):
...
@@ -440,6 +521,10 @@ def atom_num_radical_electrons(atom):
-------
-------
list
list
List containing one int only.
List containing one int only.
See Also
--------
atom_num_radical_electrons_one_hot
"""
"""
return [atom.GetNumRadicalElectrons()]
return [atom.GetNumRadicalElectrons()]
...
@@ -460,6 +545,11 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -460,6 +545,11 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_aromatic
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [False, True]
allowable_set = [False, True]
...
@@ -477,6 +567,10 @@ def atom_is_aromatic(atom):
...
@@ -477,6 +567,10 @@ def atom_is_aromatic(atom):
-------
-------
list
list
List containing one bool only.
List containing one bool only.
See Also
--------
atom_is_aromatic_one_hot
"""
"""
return [atom.GetIsAromatic()]
return [atom.GetIsAromatic()]
...
@@ -497,6 +591,11 @@ def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -497,6 +591,11 @@ def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_in_ring
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [False, True]
allowable_set = [False, True]
...
@@ -514,6 +613,10 @@ def atom_is_in_ring(atom):
...
@@ -514,6 +613,10 @@ def atom_is_in_ring(atom):
-------
-------
list
list
List containing one bool only.
List containing one bool only.
See Also
--------
atom_is_in_ring_one_hot
"""
"""
return [atom.IsInRing()]
return [atom.IsInRing()]
...
@@ -529,6 +632,10 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
...
@@ -529,6 +632,10 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
See Also
--------
one_hot_encoding
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
...
@@ -589,7 +696,8 @@ class BaseAtomFeaturizer(object):
...
@@ -589,7 +696,8 @@ class BaseAtomFeaturizer(object):
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
**We assume the resulting DGLGraph will not contain any virtual nodes and a node i in the
graph corresponds to exactly atom i in the molecule.**
Parameters
Parameters
----------
----------
...
@@ -603,7 +711,7 @@ class BaseAtomFeaturizer(object):
...
@@ -603,7 +711,7 @@ class BaseAtomFeaturizer(object):
Examples
Examples
--------
--------
>>> from dgl.data.dgllife import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from dgllife.utils import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> mol = Chem.MolFromSmiles('CCO')
...
@@ -615,6 +723,16 @@ class BaseAtomFeaturizer(object):
...
@@ -615,6 +723,16 @@ class BaseAtomFeaturizer(object):
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
>>> # Get feature size for atom mass
>>> print(atom_featurizer.feat_size('mass'))
1
>>> # Get feature size for atom degree
>>> print(atom_featurizer.feat_size('degree'))
11
See Also
--------
CanonicalAtomFeaturizer
"""
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
self.featurizer_funcs = featurizer_funcs
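The loop that BaseAtomFeaturizer's docstring describes (apply each function in ``featurizer_funcs`` to every atom, with row i belonging to atom i) can be sketched without RDKit. Everything below is a hypothetical stand-in: atoms are plain dictionaries rather than RDKit atoms, `featurize_atoms` and `feat_size` are illustrative helpers, and the degree set ``range(11)`` is chosen only to match the 11 columns of the 'degree' output shown in the docstring.

```python
def one_hot(value, allowable_set):
    """Float one-hot vector for value over allowable_set."""
    return [float(value == s) for s in allowable_set]

def featurize_atoms(atoms, featurizer_funcs):
    """Apply every featurizer to every atom; row i describes atom i."""
    feats = {name: [] for name in featurizer_funcs}
    for atom in atoms:
        for name, func in featurizer_funcs.items():
            feats[name].append([float(v) for v in func(atom)])
    return feats

def feat_size(feats, name):
    """Feature size of one field, read off the first row."""
    return len(feats[name][0])

# Stand-in for the heavy atoms of ethanol ('CCO'): C, C, O
atoms = [
    {'mass': 12.011, 'degree': 1},
    {'mass': 12.011, 'degree': 2},
    {'mass': 15.999, 'degree': 1},
]
featurizer_funcs = {
    'mass': lambda atom: [atom['mass']],
    'degree': lambda atom: one_hot(atom['degree'], list(range(11))),
}
feats = featurize_atoms(atoms, featurizer_funcs)
```

As in the docstring, the 'mass' field has size 1 and the 'degree' field has size 11, with the one-hot entries at positions 1, 2 and 1 for the three atoms.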
...
@@ -701,6 +819,38 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
...
@@ -701,6 +819,38 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
----------
----------
atom_data_field : str
atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'.
Name for storing atom features in DGLGraphs, default to be 'h'.
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import CanonicalAtomFeaturizer
>>> mol = Chem.MolFromSmiles('CCO')
>>> atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
>>> atom_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0.]])}
>>> # Get feature size for nodes
>>> print(atom_featurizer.feat_size('feat'))
74
See Also
--------
BaseAtomFeaturizer
"""
"""
def __init__(self, atom_data_field='h'):
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__(
super(CanonicalAtomFeaturizer, self).__init__(
...
@@ -734,6 +884,10 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -734,6 +884,10 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [Chem.rdchem.BondType.SINGLE,
allowable_set = [Chem.rdchem.BondType.SINGLE,
...
@@ -744,6 +898,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -744,6 +898,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated.
"""One hot encoding for whether the bond is conjugated.
Parameters
Parameters
----------
----------
bond : rdkit.Chem.rdchem.Bond
bond : rdkit.Chem.rdchem.Bond
...
@@ -753,10 +908,16 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -753,10 +908,16 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
encode_unknown : bool
If True, map inputs not in the allowable set to the
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
additional last element. (Default: False)
Returns
Returns
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_conjugated
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [False, True]
allowable_set = [False, True]
...
@@ -764,19 +925,26 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -764,19 +925,26 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated(bond):
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
"""Get whether the bond is conjugated.
Parameters
Parameters
----------
----------
bond : rdkit.Chem.rdchem.Bond
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
RDKit bond instance.
Returns
Returns
-------
-------
list
list
List containing one bool only.
List containing one bool only.
See Also
--------
bond_is_conjugated_one_hot
"""
"""
return [bond.GetIsConjugated()]
return [bond.GetIsConjugated()]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size.
"""One hot encoding for whether the bond is in a ring of any size.
Parameters
Parameters
----------
----------
bond : rdkit.Chem.rdchem.Bond
bond : rdkit.Chem.rdchem.Bond
...
@@ -786,10 +954,16 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -786,10 +954,16 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
encode_unknown : bool
If True, map inputs not in the allowable set to the
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
additional last element. (Default: False)
Returns
Returns
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_in_ring
"""
"""
if allowable_set is None:
if allowable_set is None:
allowable_set = [False, True]
allowable_set = [False, True]
...
@@ -797,19 +971,26 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -797,19 +971,26 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_in_ring(bond):
def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size.
"""Get whether the bond is in a ring of any size.
Parameters
Parameters
----------
----------
bond : rdkit.Chem.rdchem.Bond
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
RDKit bond instance.
Returns
Returns
-------
-------
list
list
List containing one bool only.
List containing one bool only.
See Also
--------
bond_is_in_ring_one_hot
"""
"""
return [bond.IsInRing()]
return [bond.IsInRing()]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond.
"""One hot encoding for the stereo configuration of a bond.
Parameters
Parameters
----------
----------
bond : rdkit.Chem.rdchem.Bond
bond : rdkit.Chem.rdchem.Bond
...
@@ -822,10 +1003,15 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
...
@@ -822,10 +1003,15 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
encode_unknown : bool
If True, map inputs not in the allowable set to the
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
additional last element. (Default: False)
Returns
Returns
-------
-------
list
list
List of boolean values where at most one value is True.
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
    allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
...
@@ -858,7 +1044,7 @@ class BaseBondFeaturizer(object):
Examples
--------
>>> from dgl.data.dgllife import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from dgllife.utils import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
...
@@ -869,6 +1055,15 @@ class BaseBondFeaturizer(object):
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'ring': tensor([[0.], [0.], [0.], [0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('type')
4
>>> bond_featurizer.feat_size('ring')
1
See Also
--------
CanonicalBondFeaturizer
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
    self.featurizer_funcs = featurizer_funcs
...
@@ -941,6 +1136,26 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
Examples
--------
>>> from dgllife.utils import CanonicalBondFeaturizer
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = CanonicalBondFeaturizer(bond_data_field='feat')
>>> bond_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('feat')
12
See Also
--------
BaseBondFeaturizer
"""
def __init__(self, bond_data_field='e'):
    super(CanonicalBondFeaturizer, self).__init__(
...
apps/life_sci/python/dgllife/utils/mol_to_graph.py
...
@@ -21,6 +21,9 @@ __all__ = ['mol_to_graph',
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize it.
This function can be used to construct any arbitrary ``DGLGraph`` from an
RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
...
@@ -41,6 +44,12 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canon
-------
g : DGLGraph
Converted DGLGraph for the molecule
See Also
--------
mol_to_bigraph
mol_to_complete_graph
mol_to_nearest_neighbor_graph
"""
if canonical_atom_order:
    new_order = rdmolfiles.CanonicalRankAtoms(mol)
...
@@ -132,6 +141,57 @@ def mol_to_bigraph(mol, add_self_loop=False,
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol, node_featurizer=featurize_atoms,
>>>                    edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
smiles_to_bigraph
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
                    node_featurizer, edge_featurizer, canonical_atom_order)
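The "bi-directed" convention used in the docstring example (one bond between atoms u and v gives the two edges (u, v) and (v, u)) can be sketched in plain Python. This is only an illustration under stated assumptions: `bigraph_edges_sketch` is a hypothetical name, bonds are given as index pairs rather than RDKit objects, and the placement of self loops at the end is an assumption, not a statement about `construct_bigraph_from_mol`.

```python
def bigraph_edges_sketch(num_atoms, bonds, add_self_loop=False):
    """Hypothetical helper: directed edge lists for a bi-directed molecular graph."""
    srcs, dsts = [], []
    for u, v in bonds:
        # One bond contributes the two directed edges (u, v) and (v, u).
        srcs.extend([u, v])
        dsts.extend([v, u])
    if add_self_loop:
        # Assumed convention: append one (i, i) edge per atom after the bond edges.
        nodes = list(range(num_atoms))
        srcs.extend(nodes)
        dsts.extend(nodes)
    return srcs, dsts


# 'CCO' has 3 atoms and 2 bonds, so 4 directed edges without self loops
srcs, dsts = bigraph_edges_sketch(3, [(0, 1), (1, 2)])
```

The edge count of 4 agrees with the `num_edges=4` shown for `'CCO'` in the docstring output, and it explains why the bond featurizer in the example appends each bond feature twice.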
...
@@ -163,6 +223,54 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_bigraph
>>> g = smiles_to_bigraph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> g = smiles_to_bigraph('CCO', node_featurizer=featurize_atoms,
>>>                       edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
mol_to_bigraph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer,
...
@@ -226,6 +334,66 @@ def mol_to_complete_graph(mol, add_self_loop=False,
-------
g : DGLGraph
Complete DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_complete_graph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> add_self_loop = True
>>> g = mol_to_complete_graph(
>>>     mol, add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
smiles_to_complete_graph
"""
return mol_to_graph(mol,
                    partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
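The edge counts in the complete-graph examples (6 edges for three atoms, 9 distance values once self loops are added) follow from enumerating ordered atom pairs. A small sketch, assuming the same `i`-then-`j` loop order as the `featurize_edges` example in the docstring; `complete_graph_edges_sketch` is a hypothetical name and not part of dgllife.

```python
def complete_graph_edges_sketch(num_atoms, add_self_loop=False):
    """Hypothetical helper: ordered (i, j) pairs of a complete directed graph."""
    return [(i, j)
            for i in range(num_atoms)
            for j in range(num_atoms)
            # Keep the (i, i) pairs only when self loops are requested.
            if i != j or add_self_loop]


edges = complete_graph_edges_sketch(3)
edges_with_loops = complete_graph_edges_sketch(3, add_self_loop=True)
```

An edge featurizer has to emit its features in the same order as the graph's edges, which is why the docstring example passes the same `add_self_loop` value to both the graph constructor and the featurizer.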
...
@@ -258,6 +426,63 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
-------
g : DGLGraph
Complete DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_complete_graph
>>> g = smiles_to_complete_graph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> add_self_loop = True
>>> g = smiles_to_complete_graph(
>>>     'CCO', add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
mol_to_complete_graph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
...
@@ -297,6 +522,31 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
Destination nodes, corresponding to ``srcs``.
distances : list of float
Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, k_nearest_neighbors
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> srcs, dsts, dists = k_nearest_neighbors(coords, neighbor_cutoff=1.25)
>>> print(srcs)
[8, 7, 11, 10, 20, 19]
>>> print(dsts)
[7, 8, 10, 11, 19, 20]
>>> print(dists)
[1.2084666104583117, 1.2084666104583117, 1.226457824344217,
1.226457824344217, 1.2230522248065987, 1.2230522248065987]
See Also
--------
get_mol_3d_coordinates
mol_to_nearest_neighbor_graph
smiles_to_nearest_neighbor_graph
"""
num_atoms = coordinates.shape[0]
model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
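The shape of the `srcs`, `dsts` and `distances` outputs can be illustrated with a brute-force sketch. This is not the function's implementation (the code above relies on scikit-learn's `NearestNeighbors`); it assumes Euclidean distance, treats the cutoff as inclusive, ignores `max_num_neighbors`, and uses the hypothetical name `neighbors_within_cutoff_sketch`.

```python
import math


def neighbors_within_cutoff_sketch(coordinates, neighbor_cutoff):
    """Hypothetical brute-force version: all ordered atom pairs within a cutoff."""
    srcs, dsts, distances = [], [], []
    for i, dst_xyz in enumerate(coordinates):
        for j, src_xyz in enumerate(coordinates):
            if i == j:
                continue  # an atom is not its own neighbor
            d = math.dist(src_xyz, dst_xyz)  # Euclidean distance (Python 3.8+)
            if d <= neighbor_cutoff:
                srcs.append(j)
                dsts.append(i)
                distances.append(d)
    return srcs, dsts, distances


# Three collinear points; only the first two are within 1.25 of each other
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
srcs, dsts, dists = neighbors_within_cutoff_sketch(coords, neighbor_cutoff=1.25)
```

Each close pair appears twice, once per direction, which matches the mirrored `srcs` and `dsts` lists and the repeated distances in the docstring output.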
...
@@ -378,6 +628,45 @@ def mol_to_nearest_neighbor_graph(mol,
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, mol_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
Quite often we will want to use the distance between end atoms of edges. This can be
achieved with
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
smiles_to_nearest_neighbor_graph
"""
if canonical_atom_order:
    new_order = rdmolfiles.CanonicalRankAtoms(mol)
...
@@ -463,6 +752,46 @@ def smiles_to_nearest_neighbor_graph(smiles,
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, smiles_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> smiles = 'CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
>>> mol = Chem.MolFromSmiles(smiles)
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
Quite often we will want to use the distance between end atoms of edges. This can be
achieved with
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
mol_to_nearest_neighbor_graph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_nearest_neighbor_graph(
...
apps/life_sci/python/dgllife/utils/rdkit_utils.py
...
@@ -16,6 +16,8 @@ __all__ = ['get_mol_3d_coordinates',
def get_mol_3d_coordinates(mol):
"""Get 3D coordinates of the molecule.
This function requires that molecular conformation has been initialized.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
...
@@ -26,6 +28,27 @@ def get_mol_3d_coordinates(mol):
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
Examples
--------
Coordinates cannot be retrieved for the molecule in the example below since the molecule
object does not carry conformation information; as noted under Returns, ``None`` is
returned in that case.
>>> from rdkit import Chem
>>> from dgllife.utils import get_mol_3d_coordinates
>>> mol = Chem.MolFromSmiles('CCO')
Below we give a working example based on molecule conformation initialized from calculation.
>>> from rdkit.Chem import AllChem
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> print(coords)
array([[ 1.20967478, -0.25802181, 0. ],
[-0.05021255, 0.57068079, 0. ],
[-1.15946223, -0.31265898, 0. ]])
"""
try:
    conf = mol.GetConformer()
...
@@ -42,13 +65,13 @@ def get_mol_3d_coordinates(mol):
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
                  remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
"""Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
Path to file for storing a molecule, which can be of format ``.mol2`` or ``.sdf``
or ``.pdbqt`` or ``.pdb``.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
...
@@ -115,13 +138,14 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
                                remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
"""Load molecules from files with multiprocessing, which can be of format ``.mol2`` or
``.sdf`` or ``.pdbqt`` or ``.pdb``.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
Each element is a path to a file storing a molecule, which can be of format ``.mol2``,
``.sdf``, ``.pdbqt``, or ``.pdb``.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
...
apps/life_sci/python/dgllife/utils/splitters.py
...
@@ -42,7 +42,8 @@ def base_k_fold_split(split_method, dataset, k, log):
Returns
-------
all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple (train_set, val_set),
which are all :class:`Subset` instances.
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = []
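The "list of 2-tuples" return shape can be illustrated on plain index lists. The sketch below is an assumption about what a consecutive k-fold split looks like, not the body of `base_k_fold_split`: it works on integer indices instead of `Subset` instances, and `k_fold_indices_sketch` is a hypothetical name.

```python
def k_fold_indices_sketch(num_items, k):
    """Hypothetical helper: consecutive k-fold split over range(num_items)."""
    assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
    indices = list(range(num_items))
    all_folds = []
    for i in range(k):
        # The i-th consecutive chunk is held out for validation.
        start = (i * num_items) // k
        stop = ((i + 1) * num_items) // k
        val_set = indices[start:stop]
        # Everything before and after the chunk is used for training.
        train_set = indices[:start] + indices[stop:]
        all_folds.append((train_set, val_set))
    return all_folds


folds = k_fold_indices_sketch(6, 3)
```

Each index lands in exactly one validation set, which is the property the k-fold docstrings in this file rely on.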
...
@@ -208,7 +209,8 @@ class ConsecutiveSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test that also have ``len(dataset)`` and
``dataset[i]`` behaviors
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
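What a consecutive train/validation/test split by fractions amounts to can be sketched on a plain list. This is an illustration only: `consecutive_split_sketch` is a hypothetical name, the rounding rule (truncating to whole items and giving the remainder to the test set) is an assumption, and the real `split_dataset` returns subset objects rather than list slices.

```python
def consecutive_split_sketch(items, frac_train=0.8, frac_val=0.1, frac_test=0.1):
    """Hypothetical helper: split a sequence in order, without shuffling."""
    assert abs(frac_train + frac_val + frac_test - 1.0) < 1e-8, \
        'Expect the fractions to sum to 1'
    num_items = len(items)
    num_train = int(frac_train * num_items)
    num_val = int(frac_val * num_items)
    train_set = items[:num_train]
    val_set = items[num_train:num_train + num_val]
    # Assumed rounding rule: whatever is left over goes to the test set.
    test_set = items[num_train + num_val:]
    return train_set, val_set, test_set


train_set, val_set, test_set = consecutive_split_sketch(list(range(10)))
```

The results support `len(...)` and indexing, which is the behavior the updated docstring promises for the returned subsets.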
...
@@ -229,7 +231,8 @@ class ConsecutiveSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
...
@@ -269,7 +272,8 @@ class RandomSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
                     shuffle=True, random_state=random_state)
...
@@ -298,7 +302,8 @@ class RandomSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
# Permute the dataset only once so that each datapoint
# will appear once in exactly one fold.
...
@@ -381,7 +386,8 @@ class MolecularWeightSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
"""
# Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time.
...
@@ -422,7 +428,8 @@ class MolecularWeightSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
...
@@ -543,7 +550,8 @@ class ScaffoldSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)`` and
``dataset[i]`` behaviors
"""
# Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
...
@@ -609,7 +617,8 @@ class ScaffoldSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
...
@@ -677,7 +686,8 @@ class SingleTaskStratifiedSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
"""
train_val_test_sanity_check(frac_train, frac_val, frac_test)
...
@@ -735,7 +745,8 @@ class SingleTaskStratifiedSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
if not isinstance(labels, np.ndarray):
    labels = F.asnumpy(labels)
...
docs/source/install/index.rst
...
@@ -12,7 +12,7 @@ DGL works with the following operating systems:
* Windows 10
DGL requires Python version 3.5 or later. Python 3.4 or earlier is not
tested. Python 2 support is coming.
tested.
DGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`.
...