Commit d3560b71 authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Documentation (#1414)

* Update

* Update

* Update
parent 7e0893e6
......@@ -21,7 +21,6 @@ class Tox21(MoleculeCSVDataset):
A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
See examples below for more details.
All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time.
......@@ -87,9 +86,18 @@ class Tox21(MoleculeCSVDataset):
def task_pos_weights(self):
"""Get weights for positive samples on each task
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
Returns
-------
numpy.ndarray
numpy array gives the weight of positive samples on all tasks
Tensor of dtype float32 and shape (T)
Weight of positive samples on all tasks
"""
return self._task_pos_weights
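The weighting rule described in the docstring above can be sketched in plain Python. This is only an illustration under stated assumptions: the function name `pos_weights_sketch`, the list-of-rows layout, and the fallback weight of 1.0 for tasks without positives are hypothetical and are not taken from the dataset class, which returns a tensor.

```python
def pos_weights_sketch(labels, masks):
    """Hypothetical helper: per-task weight = #negatives / #positives.

    labels and masks are lists of rows, one row per molecule and one
    column per task. A mask value of 0 marks a non-existing label.
    """
    num_tasks = len(labels[0])
    weights = []
    for t in range(num_tasks):
        # Count only datapoints whose label exists for task t
        num_pos = sum(1 for y, m in zip(labels, masks) if m[t] == 1 and y[t] == 1)
        num_neg = sum(1 for y, m in zip(labels, masks) if m[t] == 1 and y[t] == 0)
        # Fallback for tasks without positives is an assumption of this sketch
        weights.append(num_neg / num_pos if num_pos > 0 else 1.0)
    return weights


labels = [[1, 0], [0, 0], [0, 1], [0, 0]]
masks = [[1, 1], [1, 0], [1, 1], [1, 1]]
weights = pos_weights_sketch(labels, masks)
print(weights)
```

The second molecule has no label for the second task, so its placeholder 0 is excluded from the negative count by the mask.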
......@@ -363,15 +363,15 @@ class WLNReactionDataset(object):
Returns
-------
str
Reaction
Reaction.
str
Graph edits for the reaction
rdkit.Chem.rdchem.Mol
RDKit molecule instance
RDKit molecule instance for reactants
DGLGraph
DGLGraph for the ith molecular graph
DGLGraph for the ith molecular graph of reactants
DGLGraph
Complete DGLGraph, which will be needed for predicting
Complete DGLGraph for reactants, which will be needed for predicting
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
......@@ -477,7 +477,6 @@ class USPTO(WLNReactionDataset):
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
......
......@@ -68,15 +68,20 @@ def load_pretrained(model_name, log=True):
model_name : str
Currently supported options include
* ``'GCN_Tox21'``
* ``'GAT_Tox21'``
* ``'AttentiveFP_Aromaticity'``
* ``'DGMG_ChEMBL_canonical'``
* ``'DGMG_ChEMBL_random'``
* ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'``
* ``'wln_center_uspto'``
* ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
* ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
* ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of
aromatic atoms on a subset of PubChem
* ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical
atom order
* ``'DGMG_ChEMBL_random'``: A DGMG model trained on ChEMBL for molecule generation
with a random atom order
* ``'DGMG_ZINC_canonical'``: A DGMG model trained on ZINC for molecule generation
with a canonical atom order
* ``'DGMG_ZINC_random'``: A DGMG model pre-trained on ZINC for molecule generation
with a random atom order
* ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
* ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
log : bool
Whether to print progress for model loading
......
......@@ -22,7 +22,40 @@ class EarlyStopping(object):
The early stopping will happen if we do not observe performance
improvement for ``patience`` consecutive epochs.
filename : str or None
Filename for storing the model checkpoint
Filename for storing the model checkpoint. If not specified,
we will automatically generate a file starting with ``early_stop``
based on the current time.
Examples
--------
Below is a demo of a mock training process.
>>> import torch
>>> import torch.nn as nn
>>> from torch.nn import MSELoss
>>> from torch.optim import Adam
>>> from dgllife.utils import EarlyStopping
>>> model = nn.Linear(1, 1)
>>> criterion = MSELoss()
>>> # For MSE, the lower, the better
>>> stopper = EarlyStopping(mode='lower', filename='test.pth')
>>> optimizer = Adam(params=model.parameters(), lr=1e-3)
>>> for epoch in range(1000):
>>>     x = torch.randn(1, 1) # Fake input
>>>     y = torch.randn(1, 1) # Fake label
>>>     pred = model(x)
>>>     loss = criterion(y, pred)
>>>     optimizer.zero_grad()
>>>     loss.backward()
>>>     optimizer.step()
>>>     early_stop = stopper.step(loss.detach().data, model)
>>>     if early_stop:
>>>         break
>>> # Load the parameters from the checkpoint saved by the stopper
>>> stopper.load_checkpoint(model)
"""
def __init__(self, mode='higher', patience=10, filename=None):
if filename is None:
......
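The patience rule from the ``EarlyStopping`` docstring (stop after ``patience`` consecutive epochs without improvement) can be sketched without any model or checkpoint file. The class below, `PatienceTracker`, is a hypothetical stand-in written for illustration; it is not the dgllife class and leaves out checkpointing entirely.

```python
class PatienceTracker:
    """Hypothetical stand-in for the patience logic only (no checkpointing)."""

    def __init__(self, mode='higher', patience=10):
        assert mode in ('higher', 'lower')
        self.mode = mode
        self.patience = patience
        self.best_score = None
        self.counter = 0

    def _is_better(self, score):
        if self.mode == 'higher':
            return score > self.best_score
        return score < self.best_score

    def step(self, score):
        """Record a new score and return True when training should stop."""
        if self.best_score is None or self._is_better(score):
            self.best_score = score
            self.counter = 0  # a real stopper would save a checkpoint here
            return False
        self.counter += 1
        return self.counter >= self.patience


tracker = PatienceTracker(mode='lower', patience=2)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
    if tracker.step(loss):
        stopped_at = epoch
        break
print(stopped_at, tracker.best_score)
```

With ``patience=2`` the loop stops at the second consecutive epoch that fails to beat the best loss, so the final value 0.7 is never seen.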
......@@ -19,18 +19,44 @@ class Meter(object):
Currently we support evaluation with 4 metrics:
* pearson r2
* mae
* rmse
* roc auc score
* ``pearson r2``
* ``mae``
* ``rmse``
* ``roc auc score``
Parameters
----------
mean : torch.float32 tensor of shape (T) or None.
Mean of existing training labels across tasks if not None. T for the number of tasks.
Default to None.
Mean of existing training labels across tasks if not ``None``. ``T`` for the
number of tasks. Default to ``None`` and we assume no label normalization has been
performed.
std : torch.float32 tensor of shape (T)
Std of existing training labels across tasks if not None.
Std of existing training labels across tasks if not ``None``. Default to ``None``
and we assume no label normalization has been performed.
Examples
--------
Below is a demo of a mock evaluation epoch.
>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Simulate 10 fake mini-batches
>>> for batch_id in range(10):
>>>     batch_label = torch.randn(3, 3)
>>>     batch_pred = torch.randn(3, 3)
>>>     meter.update(batch_pred, batch_label)
>>> # Get MAE for all tasks
>>> print(meter.compute_metric('mae'))
[1.1325558423995972, 1.0543707609176636, 1.094650149345398]
>>> # Get MAE averaged over all tasks
>>> print(meter.compute_metric('mae', reduction='mean'))
1.0938589175542195
>>> # Get the sum of MAE over all tasks
>>> print(meter.compute_metric('mae', reduction='sum'))
3.2815767526626587
"""
def __init__(self, mean=None, std=None):
self.mask = []
......@@ -50,13 +76,13 @@ class Meter(object):
Parameters
----------
y_pred : float32 tensor
Predicted labels with shape (B, T),
B for number of graphs in the batch and T for the number of tasks
Predicted labels with shape ``(B, T)``,
``B`` for number of graphs in the batch and ``T`` for the number of tasks
y_true : float32 tensor
Ground truth labels with shape (B, T)
Ground truth labels with shape ``(B, T)``
mask : None or float32 tensor
Binary mask indicating the existence of ground truth labels with
shape (B, T). If None, we assume that all labels exist and create
shape ``(B, T)``. If None, we assume that all labels exist and create
a one-tensor for placeholder.
"""
self.y_pred.append(y_pred.detach().cpu())
......@@ -237,10 +263,10 @@ class Meter(object):
----------
metric_name : str
* 'r2': compute squared Pearson correlation coefficient
* 'mae': compute mean absolute error
* 'rmse': compute root mean square error
* 'roc_auc_score': compute roc-auc score
* ``'r2'``: compute squared Pearson correlation coefficient
* ``'mae'``: compute mean absolute error
* ``'rmse'``: compute root mean square error
* ``'roc_auc_score'``: compute roc-auc score
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
......
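To make the interplay of ``mask`` and ``reduction`` concrete, here is a minimal pure-Python sketch of a masked per-task MAE. The function `masked_mae` is hypothetical and only mirrors the behavior described in the docstrings above; it assumes every task has at least one existing label and does not reproduce ``Meter`` itself.

```python
def masked_mae(y_pred, y_true, mask, reduction='none'):
    """Hypothetical helper: per-task MAE over entries where mask == 1."""
    num_tasks = len(y_true[0])
    scores = []
    for t in range(num_tasks):
        # Skip datapoints whose label does not exist for task t
        errors = [abs(p[t] - y[t])
                  for p, y, m in zip(y_pred, y_true, mask) if m[t] == 1]
        scores.append(sum(errors) / len(errors))
    if reduction == 'none':
        return scores
    if reduction == 'mean':
        return sum(scores) / len(scores)
    if reduction == 'sum':
        return sum(scores)
    raise ValueError("Expected reduction to be 'none', 'mean' or 'sum'")


y_pred = [[1.0, 2.0], [3.0, 0.0], [2.0, 5.0]]
y_true = [[0.0, 2.0], [1.0, 9.0], [2.0, 4.0]]
mask = [[1, 1], [1, 0], [1, 1]]

per_task = masked_mae(y_pred, y_true, mask)
mean_score = masked_mae(y_pred, y_true, mask, reduction='mean')
sum_score = masked_mae(y_pred, y_true, mask, reduction='sum')
print(per_task, mean_score, sum_score)
```

The masked entry in the second row would contribute an error of 9 to the second task if the mask were ignored, which is why masking matters for datasets with missing labels.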
......@@ -64,6 +64,16 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False):
List of boolean values where at most one value is True.
The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
and ``len(allowable_set) + 1`` otherwise.
Examples
--------
>>> from dgllife.utils import one_hot_encoding
>>> one_hot_encoding('C', ['C', 'O'])
[True, False]
>>> one_hot_encoding('S', ['C', 'O'])
[False, False]
>>> one_hot_encoding('S', ['C', 'O'], encode_unknown=True)
[False, False, True]
"""
if encode_unknown and (allowable_set[-1] is not None):
allowable_set.append(None)
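The hunk above shows ``allowable_set.append(None)``, which appears to extend the caller's list in place. As a hedged illustration, the sketch below reproduces the three documented outputs while copying the list first; `one_hot_encoding_sketch` is a hypothetical name and the copy is a design choice of this sketch, not the library's behavior.

```python
def one_hot_encoding_sketch(x, allowable_set, encode_unknown=False):
    """Hypothetical re-implementation that leaves allowable_set untouched."""
    choices = list(allowable_set)  # copy so the caller's list is not mutated
    if encode_unknown and (not choices or choices[-1] is not None):
        choices.append(None)       # extra last slot for unknown inputs
    if x not in choices:
        x = None                   # unknown inputs map to the None slot, if any
    return [x == s for s in choices]


allowed = ['C', 'O']
known = one_hot_encoding_sketch('C', allowed)
unknown_dropped = one_hot_encoding_sketch('S', allowed)
unknown_encoded = one_hot_encoding_sketch('S', allowed, encode_unknown=True)
print(known, unknown_dropped, unknown_encoded)
```

Copying avoids surprises when the same ``allowable_set`` list is reused across calls with different ``encode_unknown`` settings.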
......@@ -98,6 +108,12 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atomic_number_one_hot
"""
if allowable_set is None:
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
......@@ -123,6 +139,12 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atomic_number
atom_type_one_hot
"""
if allowable_set is None:
allowable_set = list(range(1, 101))
......@@ -140,6 +162,11 @@ def atomic_number(atom):
-------
list
List containing one int only.
See Also
--------
atomic_number_one_hot
atom_type_one_hot
"""
return [atom.GetAtomicNum()]
......@@ -166,6 +193,9 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also
--------
one_hot_encoding
atom_degree
atom_total_degree
atom_total_degree_one_hot
"""
if allowable_set is None:
......@@ -190,7 +220,9 @@ def atom_degree(atom):
See Also
--------
atom_degree_one_hot
atom_total_degree
atom_total_degree_one_hot
"""
return [atom.GetDegree()]
......@@ -209,7 +241,10 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
See Also
--------
one_hot_encoding
atom_degree
atom_degree_one_hot
atom_total_degree
"""
if allowable_set is None:
allowable_set = list(range(6))
......@@ -218,14 +253,16 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
def atom_total_degree(atom):
"""The degree of an atom including Hs.
See Also
--------
atom_degree
Returns
-------
list
List containing one int only.
See Also
--------
atom_total_degree_one_hot
atom_degree
atom_degree_one_hot
"""
return [atom.GetTotalDegree()]
......@@ -246,6 +283,11 @@ def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_explicit_valence
"""
if allowable_set is None:
allowable_set = list(range(1, 7))
......@@ -263,6 +305,10 @@ def atom_explicit_valence(atom):
-------
list
List containing one int only.
See Also
--------
atom_explicit_valence_one_hot
"""
return [atom.GetExplicitValence()]
......@@ -283,6 +329,10 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
-------
list
List of boolean values where at most one value is True.
See Also
--------
atom_implicit_valence
"""
if allowable_set is None:
allowable_set = list(range(7))
......@@ -300,6 +350,10 @@ def atom_implicit_valence(atom):
-------
list
List containing one int only.
See Also
--------
atom_implicit_valence_one_hot
"""
return [atom.GetImplicitValence()]
......@@ -323,6 +377,10 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.HybridizationType.SP,
......@@ -349,6 +407,11 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_total_num_H
"""
if allowable_set is None:
allowable_set = list(range(5))
......@@ -366,6 +429,10 @@ def atom_total_num_H(atom):
-------
list
List containing one int only.
See Also
--------
atom_total_num_H_one_hot
"""
return [atom.GetTotalNumHs()]
......@@ -386,6 +453,11 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_formal_charge
"""
if allowable_set is None:
allowable_set = list(range(-2, 3))
......@@ -403,6 +475,10 @@ def atom_formal_charge(atom):
-------
list
List containing one int only.
See Also
--------
atom_formal_charge_one_hot
"""
return [atom.GetFormalCharge()]
......@@ -423,6 +499,11 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_num_radical_electrons
"""
if allowable_set is None:
allowable_set = list(range(5))
......@@ -440,6 +521,10 @@ def atom_num_radical_electrons(atom):
-------
list
List containing one int only.
See Also
--------
atom_num_radical_electrons_one_hot
"""
return [atom.GetNumRadicalElectrons()]
......@@ -460,6 +545,11 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_aromatic
"""
if allowable_set is None:
allowable_set = [False, True]
......@@ -477,6 +567,10 @@ def atom_is_aromatic(atom):
-------
list
List containing one bool only.
See Also
--------
atom_is_aromatic_one_hot
"""
return [atom.GetIsAromatic()]
......@@ -497,6 +591,11 @@ def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
atom_is_in_ring
"""
if allowable_set is None:
allowable_set = [False, True]
......@@ -514,6 +613,10 @@ def atom_is_in_ring(atom):
-------
list
List containing one bool only.
See Also
--------
atom_is_in_ring_one_hot
"""
return [atom.IsInRing()]
......@@ -529,6 +632,10 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
......@@ -589,7 +696,8 @@ class BaseAtomFeaturizer(object):
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
**We assume the resulting DGLGraph will not contain any virtual nodes and a node i in the
graph corresponds to exactly atom i in the molecule.**
Parameters
----------
......@@ -603,7 +711,7 @@ class BaseAtomFeaturizer(object):
Examples
--------
>>> from dgl.data.dgllife import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from dgllife.utils import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
......@@ -615,6 +723,16 @@ class BaseAtomFeaturizer(object):
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
>>> # Get feature size for atom mass
>>> print(atom_featurizer.feat_size('mass'))
1
>>> # Get feature size for atom degree
>>> print(atom_featurizer.feat_size('degree'))
11
See Also
--------
CanonicalAtomFeaturizer
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
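The idea of looping over atoms and applying each function in ``featurizer_funcs`` can be sketched without RDKit or tensors. Everything below is hypothetical: atoms are plain dicts, and `featurize_atoms_sketch`, `degree_one_hot` and `mass` are illustrative stand-ins rather than the dgllife functions of similar names.

```python
def featurize_atoms_sketch(atoms, featurizer_funcs):
    """Apply every featurizer to every atom; row i belongs to atom i."""
    features = {name: [] for name in featurizer_funcs}
    for atom in atoms:
        for name, func in featurizer_funcs.items():
            features[name].append([float(v) for v in func(atom)])
    return features


def degree_one_hot(atom, allowable_set=(0, 1, 2, 3)):
    # Hypothetical one-hot encoding over a small degree range
    return [atom['degree'] == d for d in allowable_set]


def mass(atom):
    return [atom['mass']]


# Heavy atoms of ethanol (CCO) written as plain dicts
ethanol = [{'degree': 1, 'mass': 12.011},
           {'degree': 2, 'mass': 12.011},
           {'degree': 1, 'mass': 15.999}]

feats = featurize_atoms_sketch(ethanol, {'degree': degree_one_hot, 'mass': mass})
# Feature size per field, analogous in spirit to feat_size()
feat_sizes = {name: len(rows[0]) for name, rows in feats.items()}
print(feats['degree'], feat_sizes)
```

Keeping one output row per atom, in atom order, is what lets the result be attached directly as node data when node i corresponds to atom i.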
......@@ -701,6 +819,38 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
----------
atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'.
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import CanonicalAtomFeaturizer
>>> mol = Chem.MolFromSmiles('CCO')
>>> atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
>>> atom_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0.]])}
>>> # Get feature size for nodes
>>> print(atom_featurizer.feat_size('feat'))
74
See Also
--------
BaseAtomFeaturizer
"""
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__(
......@@ -734,6 +884,10 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondType.SINGLE,
......@@ -744,6 +898,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
......@@ -753,10 +908,16 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_conjugated
"""
if allowable_set is None:
allowable_set = [False, True]
......@@ -764,19 +925,26 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
See Also
--------
bond_is_conjugated_one_hot
"""
return [bond.GetIsConjugated()]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
......@@ -786,10 +954,16 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
bond_is_in_ring
"""
if allowable_set is None:
allowable_set = [False, True]
......@@ -797,19 +971,26 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
See Also
--------
bond_is_in_ring_one_hot
"""
return [bond.IsInRing()]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
......@@ -822,10 +1003,15 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
......@@ -858,7 +1044,7 @@ class BaseBondFeaturizer(object):
Examples
--------
>>> from dgl.data.dgllife import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from dgllife.utils import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
......@@ -869,6 +1055,15 @@ class BaseBondFeaturizer(object):
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'ring': tensor([[0.], [0.], [0.], [0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('type')
4
>>> bond_featurizer.feat_size('ring')
1
See Also
--------
CanonicalBondFeaturizer
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
......@@ -941,6 +1136,26 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
Examples
--------
>>> from dgllife.utils import CanonicalBondFeaturizer
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = CanonicalBondFeaturizer(bond_data_field='feat')
>>> bond_featurizer(mol)
{'feat': tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
>>> # Get feature size
>>> bond_featurizer.feat_size('feat')
12
See Also
--------
BaseBondFeaturizer
"""
def __init__(self, bond_data_field='e'):
super(CanonicalBondFeaturizer, self).__init__(
......
......@@ -21,6 +21,9 @@ __all__ = ['mol_to_graph',
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
This function can be used to construct any arbitrary ``DGLGraph`` from an
RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
......@@ -41,6 +44,12 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canon
-------
g : DGLGraph
Converted DGLGraph for the molecule
See Also
--------
mol_to_bigraph
mol_to_complete_graph
mol_to_nearest_neighbor_graph
"""
if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol)
......@@ -132,6 +141,57 @@ def mol_to_bigraph(mol, add_self_loop=False,
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_bigraph(mol, node_featurizer=featurize_atoms,
>>>                    edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
smiles_to_bigraph
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer, canonical_atom_order)
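The examples above report 4 edges for ``'CCO'``, which has 2 bonds: each bond between atoms u and v becomes the two directed edges (u, v) and (v, u). A minimal sketch of that expansion follows; `bond_edges` is a hypothetical helper operating on index pairs, and placing self loops last is an assumption of this sketch rather than a documented ordering.

```python
def bond_edges(bonds, num_atoms, add_self_loop=False):
    """Expand undirected bonds into directed (src, dst) lists."""
    srcs, dsts = [], []
    for u, v in bonds:
        # One bond yields two directed edges
        srcs.extend([u, v])
        dsts.extend([v, u])
    if add_self_loop:
        # Ordering of self loops after bond edges is assumed here
        for i in range(num_atoms):
            srcs.append(i)
            dsts.append(i)
    return srcs, dsts


# Ethanol heavy atoms: C0-C1 and C1-O2
srcs, dsts = bond_edges([(0, 1), (1, 2)], num_atoms=3)
loop_srcs, loop_dsts = bond_edges([(0, 1), (1, 2)], num_atoms=3, add_self_loop=True)
print(srcs, dsts, len(loop_srcs))
```

This is also why the ``featurize_bonds`` examples append each bond feature twice: edge features must line up with the doubled edge list.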
......@@ -163,6 +223,54 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_bigraph
>>> g = smiles_to_bigraph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=4,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_bigraph
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_bonds(mol):
>>>     feats = []
>>>     bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
>>>                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
>>>     for bond in mol.GetBonds():
>>>         btype = bond_types.index(bond.GetBondType())
>>>         # One bond between atom u and v corresponds to two edges (u, v) and (v, u)
>>>         feats.extend([btype, btype])
>>>     return {'type': torch.tensor(feats).reshape(-1, 1).float()}
>>> g = smiles_to_bigraph('CCO', node_featurizer=featurize_atoms,
>>>                       edge_featurizer=featurize_bonds)
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['type'])
tensor([[0.],
[0.],
[0.],
[0.]])
See Also
--------
mol_to_bigraph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer,
......@@ -226,6 +334,66 @@ def mol_to_complete_graph(mol, add_self_loop=False,
-------
g : DGLGraph
Complete DGLGraph for the molecule
Examples
--------
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> mol = Chem.MolFromSmiles('CCO')
>>> g = mol_to_complete_graph(mol)
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import mol_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> mol = Chem.MolFromSmiles('CCO')
>>> add_self_loop = True
>>> g = mol_to_complete_graph(
>>>     mol, add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
smiles_to_complete_graph
"""
return mol_to_graph(mol,
partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
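The edge counts in the examples above (6 edges for 3 atoms without self loops, 9 distance entries with them) follow from enumerating every ordered atom pair. A hedged sketch of that enumeration is given below; `complete_graph_edges` is a hypothetical helper, and its row-major ordering mirrors the nested loops of the ``featurize_edges`` example, not necessarily the library's internal edge order.

```python
def complete_graph_edges(num_atoms, add_self_loop=False):
    """Enumerate ordered pairs (i, j), optionally keeping i == j."""
    return [(i, j)
            for i in range(num_atoms)
            for j in range(num_atoms)
            if i != j or add_self_loop]


edges = complete_graph_edges(3)
edges_with_loops = complete_graph_edges(3, add_self_loop=True)
print(len(edges), len(edges_with_loops))
```

In general a complete graph over V atoms has V * (V - 1) directed edges, or V^2 once self loops are included, so an edge featurizer must use the same ``add_self_loop`` setting as the graph constructor.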
......@@ -258,6 +426,63 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
-------
g : DGLGraph
Complete DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import smiles_to_complete_graph
>>> g = smiles_to_complete_graph('CCO')
>>> print(g)
DGLGraph(num_nodes=3, num_edges=6,
ndata_schemes={}
edata_schemes={})
We can also initialize node/edge features when constructing graphs.
>>> import torch
>>> from rdkit import Chem
>>> from dgllife.utils import smiles_to_complete_graph
>>> from functools import partial
>>> def featurize_atoms(mol):
>>>     feats = []
>>>     for atom in mol.GetAtoms():
>>>         feats.append(atom.GetAtomicNum())
>>>     return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}
>>> def featurize_edges(mol, add_self_loop=False):
>>>     feats = []
>>>     num_atoms = mol.GetNumAtoms()
>>>     atoms = list(mol.GetAtoms())
>>>     distance_matrix = Chem.GetDistanceMatrix(mol)
>>>     for i in range(num_atoms):
>>>         for j in range(num_atoms):
>>>             if i != j or add_self_loop:
>>>                 feats.append(float(distance_matrix[i, j]))
>>>     return {'dist': torch.tensor(feats).reshape(-1, 1).float()}
>>> add_self_loop = True
>>> g = smiles_to_complete_graph(
>>>     'CCO', add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
>>>     edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop))
>>> print(g.ndata['atomic'])
tensor([[6.],
[8.],
[6.]])
>>> print(g.edata['dist'])
tensor([[0.],
[2.],
[1.],
[2.],
[0.],
[1.],
[1.],
[1.],
[0.]])
See Also
--------
mol_to_complete_graph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
......@@ -297,6 +522,31 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
Destination nodes, corresponding to ``srcs``.
distances : list of float
Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, k_nearest_neighbors
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> srcs, dsts, dists = k_nearest_neighbors(coords, neighbor_cutoff=1.25)
>>> print(srcs)
[8, 7, 11, 10, 20, 19]
>>> print(dsts)
[7, 8, 10, 11, 19, 20]
>>> print(dists)
[1.2084666104583117, 1.2084666104583117, 1.226457824344217,
1.226457824344217, 1.2230522248065987, 1.2230522248065987]
See Also
--------
get_mol_3d_coordinates
mol_to_nearest_neighbor_graph
smiles_to_nearest_neighbor_graph
"""
num_atoms = coordinates.shape[0]
model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
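The function above delegates the search to scikit-learn's ``NearestNeighbors`` with a radius. For intuition only, the brute-force sketch below performs a comparable radius query in plain Python. The name `radius_neighbors_sketch`, the Euclidean metric, the closest-first truncation for ``max_num_neighbors``, and the neighbor-to-center edge direction are all assumptions of this sketch.

```python
import math


def radius_neighbors_sketch(coordinates, neighbor_cutoff, max_num_neighbors=None):
    """Brute-force radius search returning (srcs, dsts, distances)."""
    srcs, dsts, dists = [], [], []
    for i, ci in enumerate(coordinates):
        candidates = []
        for j, cj in enumerate(coordinates):
            if i == j:
                continue
            d = math.sqrt(sum((a - b) ** 2 for a, b in zip(ci, cj)))
            if d <= neighbor_cutoff:
                candidates.append((d, j))
        candidates.sort()  # closest neighbors first
        if max_num_neighbors is not None:
            candidates = candidates[:max_num_neighbors]
        for d, j in candidates:
            srcs.append(j)  # edge points from neighbor j to center i
            dsts.append(i)
            dists.append(d)
    return srcs, dsts, dists


coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
srcs, dsts, dists = radius_neighbors_sketch(coords, neighbor_cutoff=1.25)
print(srcs, dsts, dists)
```

The third point is 2.0 away from its nearest neighbor, so it ends up with no edges, just as most atoms in the documented example are isolated at a 1.25 cutoff.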
......@@ -378,6 +628,45 @@ def mol_to_nearest_neighbor_graph(mol,
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, mol_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> mol = Chem.MolFromSmiles('CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C')
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
To store the distance between the end atoms of each edge, set ``keep_dists=True``:
>>> g = mol_to_nearest_neighbor_graph(mol, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
smiles_to_nearest_neighbor_graph
"""
if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol)
......@@ -463,6 +752,46 @@ def smiles_to_nearest_neighbor_graph(smiles,
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
Returns
-------
g : DGLGraph
Nearest neighbor DGLGraph for the molecule
Examples
--------
>>> from dgllife.utils import get_mol_3d_coordinates, smiles_to_nearest_neighbor_graph
>>> from rdkit import Chem
>>> from rdkit.Chem import AllChem
>>> smiles = 'CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
>>> mol = Chem.MolFromSmiles(smiles)
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25)
>>> print(g)
DGLGraph(num_nodes=23, num_edges=6,
ndata_schemes={}
edata_schemes={})
To store the distance between the end atoms of each edge, set ``keep_dists=True``:
>>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25, keep_dists=True)
>>> print(g.edata['dist'])
tensor([[1.2024],
[1.2024],
[1.2270],
[1.2270],
[1.2259],
[1.2259]])
See Also
--------
get_mol_3d_coordinates
k_nearest_neighbors
mol_to_nearest_neighbor_graph
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_nearest_neighbor_graph(
......
......@@ -16,6 +16,8 @@ __all__ = ['get_mol_3d_coordinates',
def get_mol_3d_coordinates(mol):
"""Get 3D coordinates of the molecule.
This function requires that molecular conformation has been initialized.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
......@@ -26,6 +28,27 @@ def get_mol_3d_coordinates(mol):
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
Examples
--------
The call in the example below will not yield coordinates since the molecule object
does not carry conformation information. As noted above, None is returned in this case.
>>> from rdkit import Chem
>>> from dgllife.utils import get_mol_3d_coordinates
>>> mol = Chem.MolFromSmiles('CCO')
>>> get_mol_3d_coordinates(mol)
Below we give a working example based on molecule conformation initialized from calculation.
>>> from rdkit.Chem import AllChem
>>> AllChem.EmbedMolecule(mol)
>>> AllChem.MMFFOptimizeMolecule(mol)
>>> coords = get_mol_3d_coordinates(mol)
>>> print(coords)
array([[ 1.20967478, -0.25802181, 0. ],
[-0.05021255, 0.57068079, 0. ],
[-1.15946223, -0.31265898, 0. ]])
"""
try:
conf = mol.GetConformer()
......@@ -42,13 +65,13 @@ def get_mol_3d_coordinates(mol):
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
"""Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
Path to file for storing a molecule, which can be of format ``.mol2`` or ``.sdf``
or ``.pdbqt`` or ``.pdb``.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
......@@ -115,13 +138,14 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
"""Load molecules from files of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or
``.pdb`` with multiprocessing.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
Each element is a path to a file storing a molecule, which can be of format ``.mol2``,
``.sdf``, ``.pdbqt``, or ``.pdb``.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
......
......@@ -42,7 +42,8 @@ def base_k_fold_split(split_method, dataset, k, log):
Returns
-------
all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple (train_set, val_set),
which are all :class:`Subset` instances.
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = []
......@@ -208,7 +209,8 @@ class ConsecutiveSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test that also have ``len(dataset)`` and
``dataset[i]`` behaviors
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
......@@ -229,7 +231,8 @@ class ConsecutiveSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
......@@ -269,7 +272,8 @@ class RandomSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
shuffle=True, random_state=random_state)
......@@ -298,7 +302,8 @@ class RandomSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
# Permute the dataset only once so that each datapoint
# will appear once in exactly one fold.
......@@ -381,7 +386,8 @@ class MolecularWeightSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
"""
# Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time.
......@@ -422,7 +428,8 @@ class MolecularWeightSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
......@@ -543,7 +550,8 @@ class ScaffoldSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)`` and
``dataset[i]`` behaviors
"""
# Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
......@@ -609,7 +617,8 @@ class ScaffoldSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
......@@ -677,7 +686,8 @@ class SingleTaskStratifiedSplitter(object):
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
Subsets for training, validation and test, which also have ``len(dataset)``
and ``dataset[i]`` behaviors
"""
train_val_test_sanity_check(frac_train, frac_val, frac_test)
......@@ -735,7 +745,8 @@ class SingleTaskStratifiedSplitter(object):
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
"""
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
......
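The return values documented for the k-fold methods above, a list of ``(train_set, val_set)`` 2-tuples whose elements support ``len(dataset)`` and ``dataset[i]``, can be sketched without any dependency. The `IndexSubset` class and the `consecutive_k_fold` function below are hypothetical stand-ins written for illustration; they are not the classes or functions used by the library.

```python
class IndexSubset:
    """Hypothetical subset wrapper that supports len() and integer indexing."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]


def consecutive_k_fold(dataset, k):
    """Split a dataset into k consecutive folds of (train_set, val_set)."""
    assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
    n = len(dataset)
    all_folds = []
    for fold in range(k):
        # Validation block for this fold; the remaining datapoints are for training.
        start = n * fold // k
        end = n * (fold + 1) // k
        val_indices = range(start, end)
        train_indices = list(range(0, start)) + list(range(end, n))
        all_folds.append((IndexSubset(dataset, train_indices),
                          IndexSubset(dataset, val_indices)))
    return all_folds


data = list('abcdefghij')
folds = consecutive_k_fold(data, k=5)
train_set, val_set = folds[0]
```

With this layout each datapoint appears in exactly one validation set across the k folds.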
......@@ -12,7 +12,7 @@ DGL works with the following operating systems:
* Windows 10
DGL requires Python version 3.5 or later. Python 3.4 or earlier is not
tested. Python 2 support is coming.
tested.
DGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`.
......