Commit d3560b71 authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Documentation (#1414)

* Update

* Update

* Update
parent 7e0893e6
...@@ -21,7 +21,6 @@ class Tox21(MoleculeCSVDataset): ...@@ -21,7 +21,6 @@ class Tox21(MoleculeCSVDataset):
A common issue for multi-task prediction is that some datapoints are not labeled for A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation. labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
- See examples below for more details.
All molecules are converted into DGLGraphs. After the first-time construction, All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time. the DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time.
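The masking scheme described in this docstring (missing labels stored as 0, with a separate mask used in the loss) can be sketched in plain Python. This is only an illustration under stated assumptions: `masked_bce_sketch` is a hypothetical helper, not part of DGL-LifeSci, it takes nested lists rather than tensors, and it assumes every predicted probability lies strictly between 0 and 1.

```python
import math


def masked_bce_sketch(probs, labels, masks):
    """Mean binary cross-entropy over the labeled entries only.

    probs, labels and masks are lists of rows (one row per datapoint,
    one column per task). A mask value of 0 marks a placeholder label.
    """
    total, count = 0.0, 0
    for p_row, y_row, m_row in zip(probs, labels, masks):
        for p, y, m in zip(p_row, y_row, m_row):
            if m == 0:
                continue  # placeholder label: excluded from the loss
            total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
            count += 1
    # Assumption: return 0.0 when nothing is labeled at all.
    return total / count if count else 0.0


probs = [[0.9, 0.5], [0.2, 0.8]]
labels = [[1, 0], [0, 1]]  # labels[0][1] is a placeholder 0, not a real label
masks = [[1, 0], [1, 1]]
loss = masked_bce_sketch(probs, labels, masks)
```

Because the masked entry is skipped, changing its prediction leaves `loss` unchanged, which is the point of keeping the mask alongside the zero-filled labels.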
...@@ -87,9 +86,18 @@ class Tox21(MoleculeCSVDataset): ...@@ -87,9 +86,18 @@ class Tox21(MoleculeCSVDataset):
def task_pos_weights(self): def task_pos_weights(self):
"""Get weights for positive samples on each task """Get weights for positive samples on each task
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
Returns Returns
------- -------
- numpy.ndarray
- numpy array gives the weight of positive samples on all tasks
+ Tensor of dtype float32 and shape (T)
+ Weight of positive samples on all tasks
""" """
return self._task_pos_weights return self._task_pos_weights
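The weighting rule stated in the new docstring (negative samples get weight 1, positive samples get the number of negatives divided by the number of positives) can be sketched as follows. This is a minimal sketch, not the dataset's actual implementation: `pos_weights_sketch`, the nested-list inputs, and the fallback for tasks without positives are all assumptions made for illustration.

```python
def pos_weights_sketch(labels, masks):
    """Per-task weight for positive samples: (#negatives) / (#positives).

    labels and masks are lists of rows, one row per datapoint and one
    column per task. A mask of 0 marks a missing label, which is skipped.
    """
    num_tasks = len(labels[0])
    weights = []
    for t in range(num_tasks):
        num_obs = sum(1 for m in masks if m[t] == 1)
        num_pos = sum(
            1 for row, m in zip(labels, masks) if m[t] == 1 and row[t] == 1
        )
        num_neg = num_obs - num_pos
        # Assumption: fall back to 1.0 when a task has no positive labels.
        weights.append(num_neg / num_pos if num_pos > 0 else 1.0)
    return weights


labels = [[1, 0], [0, 0], [0, 1], [0, 0]]
masks = [[1, 1], [1, 1], [1, 1], [1, 0]]  # last datapoint unlabeled on task 1
weights = pos_weights_sketch(labels, masks)  # [3.0, 2.0]
```

Task 0 has one positive and three negatives, so its weight is 3.0; task 1 has one positive and two labeled negatives, so its weight is 2.0.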
...@@ -363,15 +363,15 @@ class WLNReactionDataset(object): ...@@ -363,15 +363,15 @@ class WLNReactionDataset(object):
Returns Returns
------- -------
str str
- Reaction
+ Reaction.
  str
  Graph edits for the reaction
  rdkit.Chem.rdchem.Mol
- RDKit molecule instance
+ RDKit molecule instance for reactants
  DGLGraph
- DGLGraph for the ith molecular graph
+ DGLGraph for the ith molecular graph of reactants
  DGLGraph
- Complete DGLGraph, which will be needed for predicting
+ Complete DGLGraph for reactants, which will be needed for predicting
scores between each pair of atoms scores between each pair of atoms
float32 tensor of shape (V^2, 10) float32 tensor of shape (V^2, 10)
Features for each pair of atoms. Features for each pair of atoms.
...@@ -477,7 +477,6 @@ class USPTO(WLNReactionDataset): ...@@ -477,7 +477,6 @@ class USPTO(WLNReactionDataset):
------- -------
str str
- * 'full' for the complete dataset
* 'train' for the training set * 'train' for the training set
* 'val' for the validation set * 'val' for the validation set
* 'test' for the test set * 'test' for the test set
......
...@@ -68,15 +68,20 @@ def load_pretrained(model_name, log=True): ...@@ -68,15 +68,20 @@ def load_pretrained(model_name, log=True):
model_name : str model_name : str
Currently supported options include Currently supported options include
- * ``'GCN_Tox21'``
- * ``'GAT_Tox21'``
- * ``'AttentiveFP_Aromaticity'``
- * ``'DGMG_ChEMBL_canonical'``
- * ``'DGMG_ChEMBL_random'``
- * ``'DGMG_ZINC_canonical'``
- * ``'DGMG_ZINC_random'``
- * ``'JTNN_ZINC'``
- * ``'wln_center_uspto'``
+ * ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
+ * ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
+ * ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of
+ aromatic atoms on a subset of Pubmed
+ * ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical
+ atom order
+ * ``'DGMG_ChEMBL_random'``: A DGMG model trained on ChEMBL for molecule generation
+ with a random atom order
+ * ``'DGMG_ZINC_canonical'``: A DGMG model trained on ZINC for molecule generation
+ with a canonical atom order
+ * ``'DGMG_ZINC_random'``: A DGMG model pre-trained on ZINC for molecule generation
+ with a random atom order
+ * ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
+ * ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
log : bool log : bool
Whether to print progress for model loading Whether to print progress for model loading
......
...@@ -22,7 +22,40 @@ class EarlyStopping(object): ...@@ -22,7 +22,40 @@ class EarlyStopping(object):
The early stopping will happen if we do not observe performance The early stopping will happen if we do not observe performance
improvement for ``patience`` consecutive epochs. improvement for ``patience`` consecutive epochs.
filename : str or None filename : str or None
Filename for storing the model checkpoint Filename for storing the model checkpoint. If not specified,
we will automatically generate a file starting with ``early_stop``
based on the current time.
Examples
--------
Below is a demo of a simulated training process.
>>> import torch
>>> import torch.nn as nn
>>> from torch.nn import MSELoss
>>> from torch.optim import Adam
>>> from dgllife.utils import EarlyStopping
>>> model = nn.Linear(1, 1)
>>> criterion = MSELoss()
>>> # For MSE, the lower, the better
>>> stopper = EarlyStopping(mode='lower', filename='test.pth')
>>> optimizer = Adam(params=model.parameters(), lr=1e-3)
>>> for epoch in range(1000):
...     x = torch.randn(1, 1) # Fake input
...     y = torch.randn(1, 1) # Fake label
...     pred = model(x)
...     loss = criterion(pred, y)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
...     early_stop = stopper.step(loss.detach().data, model)
...     if early_stop:
...         break
>>> # Load the parameters from the checkpoint saved by the stopper
>>> stopper.load_checkpoint(model)
""" """
def __init__(self, mode='higher', patience=10, filename=None): def __init__(self, mode='higher', patience=10, filename=None):
if filename is None: if filename is None:
......
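The stopping rule described in this docstring (stop when no improvement is observed for ``patience`` consecutive epochs) can be sketched without PyTorch. This is a simplified illustration, not the `EarlyStopping` implementation: `PatienceCounterSketch` is a hypothetical class that only tracks scores, and it omits the checkpoint saving and loading that the real class performs.

```python
class PatienceCounterSketch:
    """Track a score per epoch and signal when to stop training."""

    def __init__(self, mode='higher', patience=10):
        assert mode in ('higher', 'lower')
        self.mode = mode
        self.patience = patience
        self.best_score = None
        self.counter = 0  # epochs since the last improvement

    def _improved(self, score):
        if self.best_score is None:
            return True
        if self.mode == 'higher':
            return score > self.best_score
        return score < self.best_score

    def step(self, score):
        """Record one epoch's score; return True when training should stop."""
        if self._improved(score):
            self.best_score = score
            self.counter = 0  # a real stopper would save a checkpoint here
        else:
            self.counter += 1
        return self.counter >= self.patience


# For a loss such as MSE, lower is better.
stopper = PatienceCounterSketch(mode='lower', patience=2)
history = [stopper.step(score) for score in [1.0, 0.8, 0.9, 0.85]]
```

With ``patience=2``, the two epochs after the best score of 0.8 fail to improve on it, so the fourth call to `step` returns ``True``.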
...@@ -19,18 +19,44 @@ class Meter(object): ...@@ -19,18 +19,44 @@ class Meter(object):
Currently we support evaluation with 4 metrics: Currently we support evaluation with 4 metrics:
- * pearson r2
- * mae
- * rmse
- * roc auc score
+ * ``pearson r2``
+ * ``mae``
+ * ``rmse``
+ * ``roc auc score``
  Parameters
  ----------
  mean : torch.float32 tensor of shape (T) or None.
- Mean of existing training labels across tasks if not None. T for the number of tasks.
- Default to None.
+ Mean of existing training labels across tasks if not ``None``. ``T`` for the
+ number of tasks. Default to ``None`` and we assume no label normalization has been
+ performed.
  std : torch.float32 tensor of shape (T)
- Std of existing training labels across tasks if not None.
+ Std of existing training labels across tasks if not ``None``. Default to ``None``
+ and we assume no label normalization has been performed.
Examples
--------
Below is a demo of a simulated evaluation epoch.
>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Simulate 10 fake mini-batches
>>> for batch_id in range(10):
...     batch_label = torch.randn(3, 3)
...     batch_pred = torch.randn(3, 3)
...     meter.update(batch_pred, batch_label)
>>> # Get MAE for all tasks (values vary across runs because the inputs are random)
>>> print(meter.compute_metric('mae'))
[1.1325558423995972, 1.0543707609176636, 1.094650149345398]
>>> # Get MAE averaged over all tasks
>>> print(meter.compute_metric('mae', reduction='mean'))
1.0938589175542195
>>> # Get the sum of MAE over all tasks
>>> print(meter.compute_metric('mae', reduction='sum'))
3.2815767526626587
""" """
def __init__(self, mean=None, std=None): def __init__(self, mean=None, std=None):
self.mask = [] self.mask = []
...@@ -50,13 +76,13 @@ class Meter(object): ...@@ -50,13 +76,13 @@ class Meter(object):
Parameters Parameters
---------- ----------
y_pred : float32 tensor y_pred : float32 tensor
- Predicted labels with shape (B, T),
- B for number of graphs in the batch and T for the number of tasks
+ Predicted labels with shape ``(B, T)``,
+ ``B`` for number of graphs in the batch and ``T`` for the number of tasks
  y_true : float32 tensor
- Ground truth labels with shape (B, T)
+ Ground truth labels with shape ``(B, T)``
  mask : None or float32 tensor
  Binary mask indicating the existence of ground truth labels with
- shape (B, T). If None, we assume that all labels exist and create
+ shape ``(B, T)``. If None, we assume that all labels exist and create
a one-tensor for placeholder. a one-tensor for placeholder.
""" """
self.y_pred.append(y_pred.detach().cpu()) self.y_pred.append(y_pred.detach().cpu())
...@@ -237,10 +263,10 @@ class Meter(object): ...@@ -237,10 +263,10 @@ class Meter(object):
---------- ----------
metric_name : str metric_name : str
- * 'r2': compute squared Pearson correlation coefficient
- * 'mae': compute mean absolute error
- * 'rmse': compute root mean square error
- * 'roc_auc_score': compute roc-auc score
+ * ``'r2'``: compute squared Pearson correlation coefficient
+ * ``'mae'``: compute mean absolute error
+ * ``'rmse'``: compute root mean square error
+ * ``'roc_auc_score'``: compute roc-auc score
reduction : 'none' or 'mean' or 'sum' reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks Controls the form of scores for all tasks
......
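How the ``mask`` and ``reduction`` arguments interact can be illustrated with a small plain-Python sketch of a per-task mean absolute error. It is not the `Meter` implementation: `masked_mae_sketch` is a hypothetical function that works on nested lists, and it assumes every task has at least one labeled datapoint.

```python
def masked_mae_sketch(y_pred, y_true, mask=None, reduction='none'):
    """Per-task MAE over labeled entries, optionally reduced over tasks."""
    num_tasks = len(y_true[0])
    if mask is None:
        # Same convention as the docstring: assume all labels exist.
        mask = [[1] * num_tasks for _ in y_true]
    scores = []
    for t in range(num_tasks):
        errors = [
            abs(p[t] - y[t])
            for p, y, m in zip(y_pred, y_true, mask)
            if m[t] == 1
        ]
        scores.append(sum(errors) / len(errors))
    if reduction == 'none':
        return scores
    if reduction == 'mean':
        return sum(scores) / len(scores)
    if reduction == 'sum':
        return sum(scores)
    raise ValueError("reduction must be 'none', 'mean' or 'sum'")


y_pred = [[1.0, 2.0], [3.0, 5.0]]
y_true = [[0.0, 2.0], [1.0, 1.0]]
mask = [[1, 1], [1, 0]]  # second datapoint has no label for task 1

per_task = masked_mae_sketch(y_pred, y_true, mask)                 # [1.5, 0.0]
averaged = masked_mae_sketch(y_pred, y_true, mask, 'mean')         # 0.75
summed = masked_mae_sketch(y_pred, y_true, mask, reduction='sum')  # 1.5
```

The masked entry (prediction 5.0 against placeholder 1.0) does not contribute to task 1, which is why that task's error is 0.0.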
...@@ -64,6 +64,16 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False): ...@@ -64,6 +64,16 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False):
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
The list is of length ``len(allowable_set)`` if ``encode_unknown=False`` The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
and ``len(allowable_set) + 1`` otherwise. and ``len(allowable_set) + 1`` otherwise.
Examples
--------
>>> from dgllife.utils import one_hot_encoding
>>> one_hot_encoding('C', ['C', 'O'])
[True, False]
>>> one_hot_encoding('S', ['C', 'O'])
[False, False]
>>> one_hot_encoding('S', ['C', 'O'], encode_unknown=True)
[False, False, True]
""" """
if encode_unknown and (allowable_set[-1] is not None): if encode_unknown and (allowable_set[-1] is not None):
allowable_set.append(None) allowable_set.append(None)
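The behaviour shown in the docstring examples can be reproduced with a short standalone sketch. This is an assumption-laden illustration, not the library's code: `one_hot_encoding_sketch` is a hypothetical name, and unlike the snippet in this hunk it copies ``allowable_set`` instead of appending ``None`` to the caller's list.

```python
def one_hot_encoding_sketch(x, allowable_set, encode_unknown=False):
    """One-hot encode x against allowable_set.

    With encode_unknown=True an extra last slot (represented by None)
    catches every value that is not in allowable_set.
    """
    values = list(allowable_set)  # copy so the caller's list is untouched
    if encode_unknown and (not values or values[-1] is not None):
        values.append(None)
    if encode_unknown and x not in values:
        x = None  # route unknown inputs to the extra slot
    return [x == value for value in values]


known = one_hot_encoding_sketch('C', ['C', 'O'])
unknown_dropped = one_hot_encoding_sketch('S', ['C', 'O'])
unknown_encoded = one_hot_encoding_sketch('S', ['C', 'O'], encode_unknown=True)
```

Copying the list is a design choice of this sketch: it keeps repeated calls with a shared ``allowable_set`` from changing that list as a side effect.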
...@@ -98,6 +108,12 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -98,6 +108,12 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atomic_number_one_hot
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
...@@ -123,6 +139,12 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -123,6 +139,12 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atom_type_one_hot
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(1, 101)) allowable_set = list(range(1, 101))
...@@ -140,6 +162,11 @@ def atomic_number(atom): ...@@ -140,6 +162,11 @@ def atomic_number(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atomic_number_one_hot
atom_type_one_hot
""" """
return [atom.GetAtomicNum()] return [atom.GetAtomicNum()]
...@@ -166,6 +193,9 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -166,6 +193,9 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also See Also
-------- --------
one_hot_encoding
atom_degree
atom_total_degree
atom_total_degree_one_hot atom_total_degree_one_hot
""" """
if allowable_set is None: if allowable_set is None:
...@@ -190,7 +220,9 @@ def atom_degree(atom): ...@@ -190,7 +220,9 @@ def atom_degree(atom):
See Also See Also
-------- --------
atom_degree_one_hot
atom_total_degree atom_total_degree
atom_total_degree_one_hot
""" """
return [atom.GetDegree()] return [atom.GetDegree()]
...@@ -209,7 +241,10 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -209,7 +241,10 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also See Also
-------- --------
one_hot_encoding
atom_degree
atom_degree_one_hot atom_degree_one_hot
atom_total_degree
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(6)) allowable_set = list(range(6))
...@@ -218,14 +253,16 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -218,14 +253,16 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
def atom_total_degree(atom): def atom_total_degree(atom):
"""The degree of an atom including Hs. """The degree of an atom including Hs.
- See Also
- --------
- atom_degree
Returns Returns
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_total_degree_one_hot
atom_degree
atom_degree_one_hot
""" """
return [atom.GetTotalDegree()] return [atom.GetTotalDegree()]
...@@ -246,6 +283,11 @@ def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False ...@@ -246,6 +283,11 @@ def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_explicit_valence
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(1, 7)) allowable_set = list(range(1, 7))
...@@ -263,6 +305,10 @@ def atom_explicit_valence(atom): ...@@ -263,6 +305,10 @@ def atom_explicit_valence(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_explicit_valence_one_hot
""" """
return [atom.GetExplicitValence()] return [atom.GetExplicitValence()]
...@@ -283,6 +329,10 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False ...@@ -283,6 +329,10 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
atom_implicit_valence
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(7)) allowable_set = list(range(7))
...@@ -300,6 +350,10 @@ def atom_implicit_valence(atom): ...@@ -300,6 +350,10 @@ def atom_implicit_valence(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_implicit_valence_one_hot
""" """
return [atom.GetImplicitValence()] return [atom.GetImplicitValence()]
...@@ -323,6 +377,10 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -323,6 +377,10 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [Chem.rdchem.HybridizationType.SP, allowable_set = [Chem.rdchem.HybridizationType.SP,
...@@ -349,6 +407,11 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -349,6 +407,11 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_total_num_H
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(5)) allowable_set = list(range(5))
...@@ -366,6 +429,10 @@ def atom_total_num_H(atom): ...@@ -366,6 +429,10 @@ def atom_total_num_H(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_total_num_H_one_hot
""" """
return [atom.GetTotalNumHs()] return [atom.GetTotalNumHs()]
...@@ -386,6 +453,11 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -386,6 +453,11 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_formal_charge
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(-2, 3)) allowable_set = list(range(-2, 3))
...@@ -403,6 +475,10 @@ def atom_formal_charge(atom): ...@@ -403,6 +475,10 @@ def atom_formal_charge(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_formal_charge_one_hot
""" """
return [atom.GetFormalCharge()] return [atom.GetFormalCharge()]
...@@ -423,6 +499,11 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown= ...@@ -423,6 +499,11 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_num_radical_electrons
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = list(range(5)) allowable_set = list(range(5))
...@@ -440,6 +521,10 @@ def atom_num_radical_electrons(atom): ...@@ -440,6 +521,10 @@ def atom_num_radical_electrons(atom):
------- -------
list list
List containing one int only. List containing one int only.
See Also
--------
atom_num_radical_electrons_one_hot
""" """
return [atom.GetNumRadicalElectrons()] return [atom.GetNumRadicalElectrons()]
...@@ -460,6 +545,11 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -460,6 +545,11 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_aromatic
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [False, True] allowable_set = [False, True]
...@@ -477,6 +567,10 @@ def atom_is_aromatic(atom): ...@@ -477,6 +567,10 @@ def atom_is_aromatic(atom):
------- -------
list list
List containing one bool only. List containing one bool only.
See Also
--------
atom_is_aromatic_one_hot
""" """
return [atom.GetIsAromatic()] return [atom.GetIsAromatic()]
...@@ -497,6 +591,11 @@ def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -497,6 +591,11 @@ def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_in_ring
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [False, True] allowable_set = [False, True]
...@@ -514,6 +613,10 @@ def atom_is_in_ring(atom): ...@@ -514,6 +613,10 @@ def atom_is_in_ring(atom):
------- -------
list list
List containing one bool only. List containing one bool only.
See Also
--------
atom_is_in_ring_one_hot
""" """
return [atom.IsInRing()] return [atom.IsInRing()]
...@@ -529,6 +632,10 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -529,6 +632,10 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``, ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``, ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``. ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
See Also
--------
one_hot_encoding
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED, allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
...@@ -589,7 +696,8 @@ class BaseAtomFeaturizer(object): ...@@ -589,7 +696,8 @@ class BaseAtomFeaturizer(object):
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``. Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
- **We assume the resulting DGLGraph will not contain any virtual nodes.**
+ **We assume the resulting DGLGraph will not contain any virtual nodes and a node i in the
+ graph corresponds to exactly atom i in the molecule.**
Parameters Parameters
---------- ----------
...@@ -603,7 +711,7 @@ class BaseAtomFeaturizer(object): ...@@ -603,7 +711,7 @@ class BaseAtomFeaturizer(object):
Examples Examples
-------- --------
- >>> from dgl.data.dgllife import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
+ >>> from dgllife.utils import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem >>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO') >>> mol = Chem.MolFromSmiles('CCO')
...@@ -615,6 +723,16 @@ class BaseAtomFeaturizer(object): ...@@ -615,6 +723,16 @@ class BaseAtomFeaturizer(object):
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])} [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
>>> # Get feature size for atom mass
>>> print(atom_featurizer.feat_size('mass'))
1
>>> # Get feature size for atom degree
>>> print(atom_featurizer.feat_size('degree'))
11
See Also
--------
CanonicalAtomFeaturizer
""" """
def __init__(self, featurizer_funcs, feat_sizes=None): def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs self.featurizer_funcs = featurizer_funcs
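The loop this class describes (apply each function in ``featurizer_funcs`` to every atom, so that row i of each feature holds atom i) can be sketched without RDKit or DGL. Everything here is hypothetical and for illustration only: `FakeAtom` stands in for an RDKit atom, the two featurizer functions are simplified stand-ins, and plain lists replace tensors.

```python
class FakeAtom:
    """Hypothetical stand-in for an RDKit atom with two properties."""

    def __init__(self, mass, degree):
        self.mass = mass
        self.degree = degree


def atom_mass_sketch(atom):
    return [atom.mass]


def atom_degree_one_hot_sketch(atom, allowable_set=(0, 1, 2, 3)):
    return [float(atom.degree == d) for d in allowable_set]


def featurize_atoms_sketch(atoms, featurizer_funcs):
    """Return {feature name: list of per-atom feature rows}."""
    features = {name: [] for name in featurizer_funcs}
    for atom in atoms:  # row i of every feature belongs to atom i
        for name, func in featurizer_funcs.items():
            features[name].append(func(atom))
    return features


# Heavy atoms of ethanol (CCO): C, C, O with degrees 1, 2, 1.
ethanol = [FakeAtom(12.011, 1), FakeAtom(12.011, 2), FakeAtom(15.999, 1)]
feats = featurize_atoms_sketch(
    ethanol, {'mass': atom_mass_sketch, 'degree': atom_degree_one_hot_sketch}
)
# Feature size per name, read off the first row.
feat_sizes = {name: len(rows[0]) for name, rows in feats.items()}
```

Here the feature size of each entry is simply the length of one row, which mirrors what ``feat_size`` reports in the docstring example (with a shorter degree set in this sketch).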
...@@ -701,6 +819,38 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer): ...@@ -701,6 +819,38 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
---------- ----------
atom_data_field : str atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'. Name for storing atom features in DGLGraphs, default to be 'h'.
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import CanonicalAtomFeaturizer
>>> mol = Chem.MolFromSmiles('CCO')
>>> atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
>>> atom_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0.]])}
>>> # Get feature size for nodes
>>> print(atom_featurizer.feat_size('feat'))
74
See Also
--------
BaseAtomFeaturizer
""" """
def __init__(self, atom_data_field='h'): def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__( super(CanonicalAtomFeaturizer, self).__init__(
...@@ -734,6 +884,10 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -734,6 +884,10 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [Chem.rdchem.BondType.SINGLE, allowable_set = [Chem.rdchem.BondType.SINGLE,
...@@ -744,6 +898,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -744,6 +898,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False): def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated. """One hot encoding for whether the bond is conjugated.
Parameters Parameters
---------- ----------
bond : rdkit.Chem.rdchem.Bond bond : rdkit.Chem.rdchem.Bond
...@@ -753,10 +908,16 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -753,10 +908,16 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool encode_unknown : bool
If True, map inputs not in the allowable set to the If True, map inputs not in the allowable set to the
additional last element. (Default: False) additional last element. (Default: False)
Returns Returns
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_conjugated
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [False, True] allowable_set = [False, True]
...@@ -764,19 +925,26 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -764,19 +925,26 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated(bond): def bond_is_conjugated(bond):
"""Get whether the bond is conjugated. """Get whether the bond is conjugated.
Parameters Parameters
---------- ----------
bond : rdkit.Chem.rdchem.Bond bond : rdkit.Chem.rdchem.Bond
RDKit bond instance. RDKit bond instance.
Returns Returns
------- -------
list list
List containing one bool only. List containing one bool only.
See Also
--------
bond_is_conjugated_one_hot
""" """
return [bond.GetIsConjugated()] return [bond.GetIsConjugated()]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False): def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size. """One hot encoding for whether the bond is in a ring of any size.
Parameters Parameters
---------- ----------
bond : rdkit.Chem.rdchem.Bond bond : rdkit.Chem.rdchem.Bond
...@@ -786,10 +954,16 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -786,10 +954,16 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool encode_unknown : bool
If True, map inputs not in the allowable set to the If True, map inputs not in the allowable set to the
additional last element. (Default: False) additional last element. (Default: False)
Returns Returns
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_in_ring
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [False, True] allowable_set = [False, True]
...@@ -797,19 +971,26 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -797,19 +971,26 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_in_ring(bond): def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size. """Get whether the bond is in a ring of any size.
Parameters Parameters
---------- ----------
bond : rdkit.Chem.rdchem.Bond bond : rdkit.Chem.rdchem.Bond
RDKit bond instance. RDKit bond instance.
Returns Returns
------- -------
list list
List containing one bool only. List containing one bool only.
See Also
--------
bond_is_in_ring_one_hot
""" """
return [bond.IsInRing()] return [bond.IsInRing()]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False): def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond. """One hot encoding for the stereo configuration of a bond.
Parameters Parameters
---------- ----------
bond : rdkit.Chem.rdchem.Bond bond : rdkit.Chem.rdchem.Bond
...@@ -822,10 +1003,15 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False): ...@@ -822,10 +1003,15 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool encode_unknown : bool
If True, map inputs not in the allowable set to the If True, map inputs not in the allowable set to the
additional last element. (Default: False) additional last element. (Default: False)
Returns Returns
------- -------
list list
List of boolean values where at most one value is True. List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
""" """
if allowable_set is None: if allowable_set is None:
allowable_set = [Chem.rdchem.BondStereo.STEREONONE, allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
...@@ -858,7 +1044,7 @@ class BaseBondFeaturizer(object): ...@@ -858,7 +1044,7 @@ class BaseBondFeaturizer(object):
Examples Examples
-------- --------
- >>> from dgl.data.dgllife import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
+ >>> from dgllife.utils import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem >>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO') >>> mol = Chem.MolFromSmiles('CCO')
...@@ -869,6 +1055,15 @@ class BaseBondFeaturizer(object): ...@@ -869,6 +1055,15 @@ class BaseBondFeaturizer(object):
[1., 0., 0., 0.], [1., 0., 0., 0.],
[1., 0., 0., 0.]]), [1., 0., 0., 0.]]),
'ring': tensor([[0.], [0.], [0.], [0.]])} 'ring': tensor([[0.], [0.], [0.], [0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('type')
4
>>> bond_featurizer.feat_size('ring')
1
See Also
--------
CanonicalBondFeaturizer
""" """
def __init__(self, featurizer_funcs, feat_sizes=None): def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs self.featurizer_funcs = featurizer_funcs
...@@ -941,6 +1136,26 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer): ...@@ -941,6 +1136,26 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.** self loops.**
Examples
--------
>>> from dgllife.utils import CanonicalBondFeaturizer
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = CanonicalBondFeaturizer(bond_data_field='feat')
>>> bond_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('feat')
12
See Also
--------
BaseBondFeaturizer
""" """
def __init__(self, bond_data_field='e'): def __init__(self, bond_data_field='e'):
super(CanonicalBondFeaturizer, self).__init__( super(CanonicalBondFeaturizer, self).__init__(
......
...@@ -21,6 +21,9 @@ __all__ = ['mol_to_graph', ...@@ -21,6 +21,9 @@ __all__ = ['mol_to_graph',
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order): def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it. """Convert an RDKit molecule object into a DGLGraph and featurize for it.
This function can be used to construct any arbitrary ``DGLGraph`` from an
RDKit molecule instance.
Parameters Parameters
---------- ----------
mol : rdkit.Chem.rdchem.Mol mol : rdkit.Chem.rdchem.Mol
...@@ -41,6 +44,12 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canon ...@@ -41,6 +44,12 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canon
------- -------
g : DGLGraph g : DGLGraph
Converted DGLGraph for the molecule Converted DGLGraph for the molecule
See Also
--------
mol_to_bigraph
mol_to_complete_graph
mol_to_nearest_neighbor_graph
""" """
if canonical_atom_order: if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol) new_order = rdmolfiles.CanonicalRankAtoms(mol)
...@@ -132,6 +141,57 @@ def mol_to_bigraph(mol, add_self_loop=False, ...@@ -132,6 +141,57 @@ def mol_to_bigraph(mol, add_self_loop=False,
------- -------
g : DGLGraph g : DGLGraph
Bi-directed DGLGraph for the molecule Bi-directed DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol, node_featurizer=featurize_atoms,
>>>                    edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
smiles_to_bigraph
""" """
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop), return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer, canonical_atom_order) node_featurizer, edge_featurizer, canonical_atom_order)
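The `feats.extend([btype, btype])` line in the docstring example relies on each bond becoming two directed edges. Below is a minimal, stdlib-only sketch of that expansion. It assumes the edges for a bond are emitted as `(u, v)` immediately followed by `(v, u)`; `bonds_to_bidirected_edges` is a hypothetical helper for illustration, not a dgllife function.

```python
def bonds_to_bidirected_edges(bonds, bond_feats):
    """Expand undirected bonds into directed edges (u, v) and (v, u).

    Hypothetical helper for illustration only; not part of dgllife.
    """
    srcs, dsts, feats = [], [], []
    for (u, v), feat in zip(bonds, bond_feats):
        # Each bond contributes two edges that share the same feature.
        srcs.extend([u, v])
        dsts.extend([v, u])
        feats.extend([feat, feat])
    return srcs, dsts, feats


# 'CCO' has two single bonds, C-C and C-O (feature index 0 = single bond).
srcs, dsts, feats = bonds_to_bidirected_edges([(0, 1), (1, 2)], [0, 0])
print(srcs, dsts, feats)
```

Two bonds give four edges, which matches `num_edges=4` and the four rows of `g.edata['type']` in the example.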
...@@ -163,6 +223,54 @@ def smiles_to_bigraph(smiles, add_self_loop=False, ...@@ -163,6 +223,54 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
------- -------
g : DGLGraph g : DGLGraph
Bi-directed DGLGraph for the molecule Bi-directed DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_bigraph
>>> g = smiles_to_bigraph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> g = smiles_to_bigraph('CCO', node_featurizer=featurize_atoms,
>>>                       edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
mol_to_bigraph
""" """
mol = Chem.MolFromSmiles(smiles) mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, return mol_to_bigraph(mol, add_self_loop, node_featurizer,
...@@ -226,6 +334,66 @@ def mol_to_complete_graph(mol, add_self_loop=False, ...@@ -226,6 +334,66 @@ def mol_to_complete_graph(mol, add_self_loop=False,
------- -------
g : DGLGraph g : DGLGraph
Complete DGLGraph for the molecule Complete DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_complete_graph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> add_self_loop = True
>>> g = mol_to_complete_graph(
>>>     mol, add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
smiles_to_complete_graph
""" """
return mol_to_graph(mol, return mol_to_graph(mol,
partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop), partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
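The `featurize_edges` function in the docstring example only lines up with the graph if the edges are enumerated in the same order as its double loop. Below is a minimal, stdlib-only sketch of that enumeration. It assumes row-major order over `(i, j)` pairs, and `complete_graph_edges` is a hypothetical name, not a dgllife function.

```python
def complete_graph_edges(num_atoms, add_self_loop=False):
    """List the directed edges of a complete graph in row-major order.

    Hypothetical helper for illustration only; it mirrors the double loop
    used by ``featurize_edges`` in the docstring example.
    """
    edges = []
    for i in range(num_atoms):
        for j in range(num_atoms):
            # Skip (i, i) unless self loops are requested.
            if i != j or add_self_loop:
                edges.append((i, j))
    return edges


without_loops = complete_graph_edges(3)
with_loops = complete_graph_edges(3, add_self_loop=True)
print(len(without_loops), len(with_loops))
```

For three atoms this gives N * (N - 1) = 6 edges without self loops and N * N = 9 with them, which matches `num_edges=6` and the nine rows of `g.edata['dist']` in the examples.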
...@@ -258,6 +426,63 @@ def smiles_to_complete_graph(smiles, add_self_loop=False, ...@@ -258,6 +426,63 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
------- -------
g : DGLGraph g : DGLGraph
Complete DGLGraph for the molecule Complete DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_complete_graph
>>> g = smiles_to_complete_graph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> add_self_loop = True
>>> g = smiles_to_complete_graph(
>>>     'CCO', add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
mol_to_complete_graph
""" """
mol = Chem.MolFromSmiles(smiles) mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
...@@ -297,6 +522,31 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None, ...@@ -297,6 +522,31 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
Destination nodes, corresponding to ``srcs``. Destination nodes, corresponding to ``srcs``.
distances : list of float distances : list of float
Distances between the end nodes, corresponding to ``srcs`` and ``dsts``. Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, k_nearest_neighbors
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> srcs, dsts, dists = k_nearest_neighbors(coords, neighbor_cutoff=1.25)
>>> print(srcs)
[8, 7, 11, 10, 20, 19]
>>> print(dsts)
[7, 8, 10, 11, 19, 20]
>>> print(dists)
[1.2084666104583117, 1.2084666104583117, 1.226457824344217,
1.226457824344217, 1.2230522248065987, 1.2230522248065987]
See Also
--------
get_mol_3d_coordinates
mol_to_nearest_neighbor_graph
smiles_to_nearest_neighbor_graph
""" """
num_atoms = coordinates.shape[0] num_atoms = coordinates.shape[0]
model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance) model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
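The function itself delegates the search to scikit-learn's `NearestNeighbors`. The brute-force sketch below swaps in a plain double loop to show what the returned `srcs`, `dsts` and `distances` mean. It handles Euclidean distance only, ignores `max_num_neighbors`, and treats the cutoff as inclusive. Those choices and the name `neighbors_within_cutoff` are assumptions for illustration.

```python
import math


def neighbors_within_cutoff(coordinates, neighbor_cutoff):
    """Brute-force neighbor search (illustrative stand-in, not the library code).

    For every atom i, each other atom j within ``neighbor_cutoff`` yields
    one edge j -> i, so every close pair appears in both directions.
    """
    srcs, dsts, dists = [], [], []
    for i, coord_i in enumerate(coordinates):
        for j, coord_j in enumerate(coordinates):
            if i == j:
                continue
            d = math.dist(coord_i, coord_j)  # Euclidean distance
            if d <= neighbor_cutoff:
                srcs.append(j)
                dsts.append(i)
                dists.append(d)
    return srcs, dsts, dists


# Three points on a line; only the first two are within 1.25 of each other.
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)]
srcs, dsts, dists = neighbors_within_cutoff(coords, neighbor_cutoff=1.25)
print(srcs, dsts, dists)
```

As in the docstring output, each close pair shows up twice with the same distance, once per direction.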
...@@ -378,6 +628,45 @@ def mol_to_nearest_neighbor_graph(mol, ...@@ -378,6 +628,45 @@ def mol_to_nearest_neighbor_graph(mol,
dist_field : str dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``. into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, mol_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
Quite often we will want to use the distances between the end atoms of edges, which can be
kept by setting ``keep_dists=True``:
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
smiles_to_nearest_neighbor_graph
""" """
if canonical_atom_order: if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol) new_order = rdmolfiles.CanonicalRankAtoms(mol)
...@@ -463,6 +752,46 @@ def smiles_to_nearest_neighbor_graph(smiles, ...@@ -463,6 +752,46 @@ def smiles_to_nearest_neighbor_graph(smiles,
dist_field : str dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``. into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, smiles_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> smiles = 'CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
>>> mol = Chem.MolFromSmiles(smiles)
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
Quite often we will want to use the distances between the end atoms of edges, which can be
kept by setting ``keep_dists=True``:
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
mol_to_nearest_neighbor_graph
""" """
mol = Chem.MolFromSmiles(smiles) mol = Chem.MolFromSmiles(smiles)
return mol_to_nearest_neighbor_graph( return mol_to_nearest_neighbor_graph(
......
...@@ -16,6 +16,8 @@ __all__ = ['get_mol_3d_coordinates', ...@@ -16,6 +16,8 @@ __all__ = ['get_mol_3d_coordinates',
def get_mol_3d_coordinates(mol): def get_mol_3d_coordinates(mol):
"""Get 3D coordinates of the molecule. """Get 3D coordinates of the molecule.
This function requires that molecular conformation has been initialized.
Parameters Parameters
---------- ----------
mol : rdkit.Chem.rdchem.Mol mol : rdkit.Chem.rdchem.Mol
...@@ -26,6 +28,27 @@ def get_mol_3d_coordinates(mol): ...@@ -26,6 +28,27 @@ def get_mol_3d_coordinates(mol):
numpy.ndarray of shape (N, 3) or None numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned. the molecule. For failures in getting the conformations, None will be returned.
Examples
--------
No coordinates can be obtained for the molecule in the example below, since the molecule
object does not carry conformation information. In this case ``None`` is returned.
>>> from rdkit import Chem
>>> from dgllife.utils import get_mol_3d_coordinates
>>> mol = Chem.MolFromSmiles('CCO')
Below we give a working example based on molecule conformation initialized from calculation.
>>> from rdkit.Chem import AllChem
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> coords
array([[ 1.20967478, -0.25802181, 0. ],
[-0.05021255, 0.57068079, 0. ],
[-1.15946223, -0.31265898, 0. ]])
""" """
try: try:
conf = mol.GetConformer() conf = mol.GetConformer()
...@@ -42,13 +65,13 @@ def get_mol_3d_coordinates(mol): ...@@ -42,13 +65,13 @@ def get_mol_3d_coordinates(mol):
# pylint: disable=E1101 # pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False, def load_molecule(molecule_file, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True): remove_hs=False, use_conformation=True):
"""Load a molecule from a file. """Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.
Parameters Parameters
---------- ----------
molecule_file : str molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf', Path to file for storing a molecule, which can be of format ``.mol2`` or ``.sdf``
'.pdbqt', or '.pdb'. or ``.pdbqt`` or ``.pdb``.
sanitize : bool sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization. https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
...@@ -115,13 +138,14 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False, ...@@ -115,13 +138,14 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False, def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2): remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing. """Load molecules from files with multiprocessing, which can be of format ``.mol2`` or
``.sdf`` or ``.pdbqt`` or ``.pdb``.
Parameters Parameters
---------- ----------
files : list of str files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2', Each element is a path to a file storing a molecule, which can be of format ``.mol2``,
'.sdf', '.pdbqt', or '.pdb'. ``.sdf``, ``.pdbqt``, or ``.pdb``.
sanitize : bool sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization. https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
......
...@@ -42,7 +42,8 @@ def base_k_fold_split(split_method, dataset, k, log): ...@@ -42,7 +42,8 @@ def base_k_fold_split(split_method, dataset, k, log):
Returns Returns
------- -------
all_folds : list of 2-tuples all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple (train_set, val_set),
which are all :class:`Subset` instances.
""" """
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k) assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = [] all_folds = []
...@@ -208,7 +209,8 @@ class ConsecutiveSplitter(object): ...@@ -208,7 +209,8 @@ class ConsecutiveSplitter(object):
Returns Returns
------- -------
list of length 3 list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances. Subsets for training, validation and test that also have ``len(dataset)`` and
``dataset[i]`` behaviors
""" """
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False) return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
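The returned subsets only need to support `len(subset)` and `subset[i]`. Below is a minimal, stdlib-only sketch of such an object together with a consecutive split. `IndexSubset` and `consecutive_split` are hypothetical names, and the rounding (truncate for train and validation, remainder to test) is an assumption that may differ from what `split_dataset` does.

```python
class IndexSubset:
    """Hypothetical stand-in for a ``Subset``: a view over selected indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Map the subset-local index to an index in the full dataset.
        return self.dataset[self.indices[i]]


def consecutive_split(dataset, frac_train=0.8, frac_val=0.1):
    """Split into consecutive train/val/test chunks without shuffling.

    Whatever is left after the train and validation chunks goes to test.
    """
    n = len(dataset)
    n_train = int(frac_train * n)
    n_val = int(frac_val * n)
    train = IndexSubset(dataset, range(0, n_train))
    val = IndexSubset(dataset, range(n_train, n_train + n_val))
    test = IndexSubset(dataset, range(n_train + n_val, n))
    return [train, val, test]


smiles = ['C', 'CC', 'CCC', 'CCO', 'CCN', 'CO', 'CN', 'C=O', 'C#N', 'CCCC']
train_set, val_set, test_set = consecutive_split(smiles)
print(len(train_set), len(val_set), len(test_set))
print(train_set[0], test_set[0])
```

Because the split is consecutive, the first datapoint always lands in the training set and the last one in the test set.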
...@@ -229,7 +231,8 @@ class ConsecutiveSplitter(object): ...@@ -229,7 +231,8 @@ class ConsecutiveSplitter(object):
Returns Returns
------- -------
list of 2-tuples list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
""" """
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log) return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
...@@ -269,7 +272,8 @@ class RandomSplitter(object): ...@@ -269,7 +272,8 @@ class RandomSplitter(object):
Returns Returns
------- -------
list of length 3 list of length 3
Subsets for training, validation and test. Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors.
""" """
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
shuffle=True, random_state=random_state) shuffle=True, random_state=random_state)
...@@ -298,7 +302,8 @@ class RandomSplitter(object): ...@@ -298,7 +302,8 @@ class RandomSplitter(object):
Returns Returns
------- -------
list of 2-tuples list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
""" """
# Permute the dataset only once so that each datapoint # Permute the dataset only once so that each datapoint
# will appear once in exactly one fold. # will appear once in exactly one fold.
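The comment above states the key property of the random k-fold split: the dataset is permuted once, so the validation folds partition it. Below is a minimal, stdlib-only sketch of that idea. It works on indices only, and `random_k_fold_indices` and its fold boundaries are assumptions for illustration, not the library code.

```python
import random


def random_k_fold_indices(num_datapoints, k=5, seed=0):
    """Shuffle the indices once, then cut the permutation into k folds.

    Hypothetical helper for illustration only.
    """
    assert k >= 2, 'Expect the number of folds to be no smaller than 2'
    indices = list(range(num_datapoints))
    random.Random(seed).shuffle(indices)  # permute only once
    folds = []
    for i in range(k):
        start = num_datapoints * i // k
        stop = num_datapoints * (i + 1) // k
        val = indices[start:stop]
        # Everything outside the validation slice is used for training.
        train = indices[:start] + indices[stop:]
        folds.append((train, val))
    return folds


folds = random_k_fold_indices(10, k=5, seed=0)
all_val = sorted(idx for _, val in folds for idx in val)
print(all_val == list(range(10)))
```

Shuffling inside the loop instead would let a datapoint fall into several validation folds, or into none.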
...@@ -381,7 +386,8 @@ class MolecularWeightSplitter(object): ...@@ -381,7 +386,8 @@ class MolecularWeightSplitter(object):
Returns Returns
------- -------
list of length 3 list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances. Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
""" """
# Perform sanity check first as molecule instance initialization and descriptor # Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time. # computation can take a long time.
...@@ -422,7 +428,8 @@ class MolecularWeightSplitter(object): ...@@ -422,7 +428,8 @@ class MolecularWeightSplitter(object):
Returns Returns
------- -------
list of 2-tuples list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
""" """
molecules = prepare_mols(dataset, mols, sanitize, log_every_n) molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n) sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
...@@ -543,7 +550,8 @@ class ScaffoldSplitter(object): ...@@ -543,7 +550,8 @@ class ScaffoldSplitter(object):
Returns Returns
------- -------
list of length 3 list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances. Subsets for training, validation and test, which also have ``len(dataset)`` and
``dataset[i]`` behaviors
""" """
# Perform sanity check first as molecule related computation can take a long time. # Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test) train_val_test_sanity_check(frac_train, frac_val, frac_test)
...@@ -609,7 +617,8 @@ class ScaffoldSplitter(object): ...@@ -609,7 +617,8 @@ class ScaffoldSplitter(object):
Returns Returns
------- -------
list of 2-tuples list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
""" """
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k) assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
...@@ -677,7 +686,8 @@ class SingleTaskStratifiedSplitter(object): ...@@ -677,7 +686,8 @@ class SingleTaskStratifiedSplitter(object):
Returns Returns
------- -------
list of length 3 list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances. Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
""" """
train_val_test_sanity_check(frac_train, frac_val, frac_test) train_val_test_sanity_check(frac_train, frac_val, frac_test)
...@@ -735,7 +745,8 @@ class SingleTaskStratifiedSplitter(object): ...@@ -735,7 +745,8 @@ class SingleTaskStratifiedSplitter(object):
Returns Returns
------- -------
list of 2-tuples list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set). Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
""" """
if not isinstance(labels, np.ndarray): if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels) labels = F.asnumpy(labels)
......
...@@ -12,7 +12,7 @@ DGL works with the following operating systems: ...@@ -12,7 +12,7 @@ DGL works with the following operating systems:
* Windows 10 * Windows 10
DGL requires Python version 3.5 or later. Python 3.4 or earlier is not DGL requires Python version 3.5 or later. Python 3.4 or earlier is not
tested. Python 2 support is coming. tested.
DGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`. DGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`.
......