Unverified Commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
[MASTER]
# Adapted from github.com/dmlc/dgl/tests/lint/pylintrc
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,_cy2,_cy3,backend,data,contrib
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=4
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=design,
similarities,
no-self-use,
attribute-defined-outside-init,
locally-disabled,
star-args,
pointless-except,
bad-option-value,
global-statement,
fixme,
suppressed-message,
useless-suppression,
locally-enabled,
import-error,
unsubscriptable-object,
unbalanced-tuple-unpacking,
protected-access,
useless-object-inheritance,
no-else-return,
len-as-condition,
cyclic-import, # disabled due to the inevitable dgl.graph -> dgl.subgraph loop
undefined-variable, # disabled due to C extension (should enable)
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a score less than or equal to 10 (10
# is the highest score). You have access to the variables 'error', 'warning',
# 'refactor', 'convention' and 'statement', which respectively contain the
# number of messages in each category and the total number of statements
# analyzed. This is used by the global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
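# For example, with this formula a module with 100 statements, 1 error and 2
# warnings (and no refactor or convention messages) scores
# 10.0 - ((5 * 1 + 2 + 0 + 0) / 100) * 10 = 9.3.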
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never return. When checking for
# inconsistent-return-statements, if a never-returning function is called, then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[MISCELLANEOUS]
# List of note tags to take into consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,u,v,e,n,m,w,x,y,g,G,hg,fn,ex,Run,_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
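# For example, with the pattern above a function named _helper is exempt from
# the missing-docstring check, while one named helper is not.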
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken into consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
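# For example, with the defaults above names such as cb_update or
# on_epoch_end_cb are treated as callbacks.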
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
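# For example, the default pattern above treats names such as "_", "__",
# "_unused", "dummy", "ignored_result" and "unused_arg" as intentionally
# unused.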
# Argument names that match this expression will be ignored. Defaults to names
# with a leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
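# For example, a comment line holding nothing but a long URL, such as
# "# https://github.com/dmlc/dgl", is exempt from max-line-length.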
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=4000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=dgl.backend,dgl._api_internal
# Show a hint with possible names when a member name was not found. Hints are
# based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken into consideration
# showing a hint for a missing member.
missing-member-max-choices=1
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=yes
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of all (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branches for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield statements for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception".
overgeneral-exceptions=Exception
\ No newline at end of file
import dgl
import os
import shutil
import torch
from dgl.data.utils import _get_dgl_url, download, extract_archive
from dgllife.model.model_zoo.acnn import ACNN
from dgllife.utils.complex_to_graph import ACNN_graph_construction_and_featurization
from dgllife.utils.rdkit_utils import load_molecule
def remove_dir(dirname):
    if os.path.isdir(dirname):
        try:
            shutil.rmtree(dirname)
        except OSError:
            pass
def test_acnn():
remove_dir('tmp1')
remove_dir('tmp2')
url = _get_dgl_url('dgllife/example_mols.tar.gz')
local_path = 'tmp1/example_mols.tar.gz'
download(url, path=local_path)
extract_archive(local_path, 'tmp2')
pocket_mol, pocket_coords = load_molecule(
'tmp2/example_mols/example.pdb', remove_hs=True)
ligand_mol, ligand_coords = load_molecule(
'tmp2/example_mols/example.pdbqt', remove_hs=True)
remove_dir('tmp1')
remove_dir('tmp2')
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g1 = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords)
model = ACNN()
model.to(device)
    g1 = g1.to(device)
assert model(g1).shape == torch.Size([1, 1])
bg = dgl.batch_hetero([g1, g1])
    bg = bg.to(device)
assert model(bg).shape == torch.Size([2, 1])
model = ACNN(hidden_sizes=[1, 2],
weight_init_stddevs=[1, 1],
dropouts=[0.1, 0.],
features_to_use=torch.tensor([6., 8.]),
radial=[[12.0], [0.0, 2.0], [4.0]])
model.to(device)
    g1 = g1.to(device)
assert model(g1).shape == torch.Size([1, 1])
bg = dgl.batch_hetero([g1, g1])
    bg = bg.to(device)
assert model(bg).shape == torch.Size([2, 1])
if __name__ == '__main__':
test_acnn()
import torch
from rdkit import Chem
from dgllife.model import DGMG, DGLJTNNVAE
def test_dgmg():
model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
bond_types=[Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE],
node_hidden_size=1,
num_prop_rounds=1,
dropout=0.2)
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)], rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
bond_types=[Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE])
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)], rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def test_jtnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
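    # Smoke test: constructing the model on the target device should not raise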
model = DGLJTNNVAE(hidden_size=1,
latent_size=2,
depth=1).to(device)
if __name__ == '__main__':
test_dgmg()
test_jtnn()
import dgl
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.gnn import *
def test_graph1():
"""Graph with node features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"""Batched graph with node features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_graph3():
"""Graph with node and edge features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)
def test_graph4():
"""Batched graph with node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)
def test_graph5():
"""Graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)
def test_graph6():
"""Batched graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_graph7():
"""Graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])
def test_graph8():
"""Batched graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
torch.LongTensor([2, 3, 2, 0, 1, 2, 1])
def test_gcn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gnn = GCN(in_feats=1).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 64])
# Test configured setting
gnn = GCN(in_feats=1,
hidden_feats=[1, 1],
activation=[F.relu, F.relu],
residual=[True, True],
batchnorm=[True, True],
dropout=[0.2, 0.2]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])
def test_gat():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gnn = GAT(in_feats=1).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 32])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 32])
# Test configured setting
gnn = GAT(in_feats=1,
hidden_feats=[1, 1],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.2, 0.2],
residuals=[True, True],
agg_modes=['flatten', 'mean'],
activations=[None, F.elu]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])
gnn = GAT(in_feats=1,
hidden_feats=[1, 1],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.2, 0.2],
residuals=[True, True],
agg_modes=['mean', 'flatten'],
activations=[None, F.elu]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 3])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 3])
def test_attentive_fp_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test AttentiveFPGNN
gnn = AttentiveFPGNN(node_feat_size=1,
edge_feat_size=2,
num_layers=1,
graph_feat_size=1,
dropout=0.).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 1])
def test_schnet_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
gnn = SchNetGNN().to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 64])
# Test configured setting
gnn = SchNetGNN(num_node_types=5,
node_feats=2,
hidden_feats=[3],
cutoff=0.3).to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 2])
def test_mgcn_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
gnn = MGCNGNN().to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 512])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 512])
# Test configured setting
gnn = MGCNGNN(feats=2,
n_layers=2,
num_node_types=5,
num_edge_types=150,
cutoff=0.3).to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 6])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 6])
def test_mpnn_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
gnn = MPNNGNN(node_in_feats=1,
edge_in_feats=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 64])
# Test configured setting
gnn = MPNNGNN(node_in_feats=1,
edge_in_feats=2,
node_out_feats=2,
edge_hidden_feats=2,
num_step_message_passing=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
def test_wln():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
gnn = WLN(node_in_feats=1,
edge_in_feats=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 300])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 300])
# Test configured setting
gnn = WLN(node_in_feats=1,
edge_in_feats=2,
node_out_feats=3,
n_layers=1).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 3])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 3])
def test_weave():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
gnn = WeaveGNN(node_in_feats=1,
edge_in_feats=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 50])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 50])
# Test configured setting
gnn = WeaveGNN(node_in_feats=1,
edge_in_feats=2,
num_layers=1,
hidden_feats=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
def test_gin():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
bg, batch_node_feats1, batch_node_feats2, \
batch_edge_feats1, batch_edge_feats2 = test_graph8()
batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
batch_node_feats2.to(device)
batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
batch_edge_feats2.to(device)
# Test default setting
gnn = GIN(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4]).to(device)
assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([3, 300])
assert gnn(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 300])
# Test configured setting
gnn = GIN(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4],
num_layers=2, emb_dim=10, JK='concat', dropout=0.1).to(device)
assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([3, 30])
assert gnn(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 30])
if __name__ == '__main__':
test_gcn()
test_gat()
test_attentive_fp_gnn()
test_schnet_gnn()
test_mgcn_gnn()
test_mpnn_gnn()
test_wln()
test_weave()
test_gin()
import dgl
import os
import torch
from functools import partial
from dgllife.model import load_pretrained
from dgllife.utils import *
def remove_file(fname):
if os.path.isfile(fname):
try:
os.remove(fname)
except OSError:
pass
def run_dgmg_ChEMBL(model):
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def run_dgmg_ZINC(model):
assert model(
actions=[(0, 2), (1, 3), (0, 5), (1, 0), (2, 0), (1, 3), (0, 9)],
rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def test_dgmg():
model = load_pretrained('DGMG_ZINC_canonical')
run_dgmg_ZINC(model)
model = load_pretrained('DGMG_ZINC_random')
run_dgmg_ZINC(model)
model = load_pretrained('DGMG_ChEMBL_canonical')
run_dgmg_ChEMBL(model)
model = load_pretrained('DGMG_ChEMBL_random')
run_dgmg_ChEMBL(model)
remove_file('DGMG_ChEMBL_canonical_pre_trained.pth')
remove_file('DGMG_ChEMBL_random_pre_trained.pth')
remove_file('DGMG_ZINC_canonical_pre_trained.pth')
remove_file('DGMG_ZINC_random_pre_trained.pth')
def test_jtnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
model = load_pretrained('JTNN_ZINC').to(device)
remove_file('JTNN_ZINC_pre_trained.pth')
def test_gcn_tox21():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = CanonicalAtomFeaturizer()
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('GCN_Tox21').to(device)
model(bg.to(device), bg.ndata.pop('h').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('h').to(device))
remove_file('GCN_Tox21_pre_trained.pth')
def test_gat_tox21():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = CanonicalAtomFeaturizer()
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('GAT_Tox21').to(device)
model(bg.to(device), bg.ndata.pop('h').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('h').to(device))
remove_file('GAT_Tox21_pre_trained.pth')
def test_weave_tox21():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = WeaveAtomFeaturizer()
edge_featurizer = WeaveEdgeFeaturizer(max_distance=2)
g1 = smiles_to_complete_graph('CO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, add_self_loop=True)
g2 = smiles_to_complete_graph('CCO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, add_self_loop=True)
bg = dgl.batch([g1, g2])
model = load_pretrained('Weave_Tox21').to(device)
model(bg.to(device), bg.ndata.pop('h').to(device), bg.edata.pop('e').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('h').to(device), g1.edata.pop('e').to(device))
remove_file('Weave_Tox21_pre_trained.pth')
def chirality(atom):
    # RDKit only sets the '_CIPCode' property ('R' or 'S') on atoms with
    # assigned stereochemistry; GetProp raises KeyError when it is absent.
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
def test_attentivefp_aromaticity():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # a placeholder for aromaticity information
atom_total_num_H_one_hot, chirality
],
)}
)
edge_featurizer = BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]
})
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('AttentiveFP_Aromaticity').to(device)
model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))
remove_file('AttentiveFP_Aromaticity_pre_trained.pth')
if __name__ == '__main__':
test_dgmg()
test_jtnn()
test_gcn_tox21()
test_gat_tox21()
test_attentivefp_aromaticity()
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.model_zoo import *
def test_graph1():
"""Graph with node features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"""Batched graph with node features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_graph3():
"""Graph with node features and edge features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)
def test_graph4():
"""Batched graph with node features and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)
def test_graph5():
"""Graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)
def test_graph6():
"""Batched graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_graph7():
"""Graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])
def test_graph8():
"""Batched graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
torch.LongTensor([2, 3, 2, 0, 1, 2, 1])
def test_mlp_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g_feats = torch.tensor([[1.], [2.]]).to(device)
mlp_predictor = MLPPredictor(in_feats=1, hidden_feats=1, n_tasks=2).to(device)
assert mlp_predictor(g_feats).shape == torch.Size([2, 2])
def test_gcn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gcn_predictor = GCNPredictor(in_feats=1).to(device)
gcn_predictor.eval()
assert gcn_predictor(g, node_feats).shape == torch.Size([1, 1])
gcn_predictor.train()
assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
# Test configured setting
gcn_predictor = GCNPredictor(in_feats=1,
hidden_feats=[1],
activation=[F.relu],
residual=[True],
batchnorm=[True],
dropout=[0.1],
classifier_hidden_feats=1,
classifier_dropout=0.1,
n_tasks=2).to(device)
gcn_predictor.eval()
assert gcn_predictor(g, node_feats).shape == torch.Size([1, 2])
gcn_predictor.train()
assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 2])
def test_gat_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gat_predictor = GATPredictor(in_feats=1).to(device)
gat_predictor.eval()
assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
gat_predictor.train()
assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
# Test configured setting
gat_predictor = GATPredictor(in_feats=1,
hidden_feats=[1, 2],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.1, 0.1],
residuals=[True, True],
agg_modes=['mean', 'flatten'],
activations=[None, F.elu]).to(device)
gat_predictor.eval()
assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
gat_predictor.train()
assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
def test_attentivefp_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
attentivefp_predictor = AttentiveFPPredictor(node_feat_size=1,
edge_feat_size=2,
num_layers=2,
num_timesteps=1,
graph_feat_size=1,
n_tasks=2).to(device)
assert attentivefp_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
assert attentivefp_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
def test_schnet_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
schnet_predictor = SchNetPredictor().to(device)
assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 1])
# Test configured setting
schnet_predictor = SchNetPredictor(node_feats=2,
hidden_feats=[2, 2],
classifier_hidden_feats=3,
n_tasks=3,
num_node_types=5,
cutoff=0.3).to(device)
assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 3])
def test_mgcn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
mgcn_predictor = MGCNPredictor().to(device)
assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 1])
# Test configured setting
mgcn_predictor = MGCNPredictor(feats=2,
n_layers=2,
classifier_hidden_feats=3,
n_tasks=3,
num_node_types=5,
num_edge_types=150,
cutoff=0.3).to(device)
assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 3])
def test_mpnn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
mpnn_predictor = MPNNPredictor(node_in_feats=1,
edge_in_feats=2).to(device)
assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 1])
assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 1])
# Test configured setting
mpnn_predictor = MPNNPredictor(node_in_feats=1,
edge_in_feats=2,
node_out_feats=2,
edge_hidden_feats=2,
n_tasks=2,
num_step_message_passing=2,
num_step_set2set=2,
num_layer_set2set=2).to(device)
assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
def test_weave_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
weave_predictor = WeavePredictor(node_in_feats=1,
edge_in_feats=2).to(device)
assert weave_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 1])
# Test configured setting
weave_predictor = WeavePredictor(node_in_feats=1,
edge_in_feats=2,
num_gnn_layers=2,
gnn_hidden_feats=10,
gnn_activation=F.relu,
graph_feats=128,
gaussian_expand=True,
gaussian_memberships=None,
readout_activation=nn.Tanh(),
n_tasks=2).to(device)
assert weave_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
def test_gin_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
bg, batch_node_feats1, batch_node_feats2, \
batch_edge_feats1, batch_edge_feats2 = test_graph8()
batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
batch_node_feats2.to(device)
batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
batch_edge_feats2.to(device)
num_node_emb_list = [3, 5]
num_edge_emb_list = [3, 4]
for JK in ['concat', 'last', 'max', 'sum']:
for readout in ['sum', 'mean', 'max', 'attention']:
model = GINPredictor(num_node_emb_list=num_node_emb_list,
num_edge_emb_list=num_edge_emb_list,
num_layers=2,
emb_dim=10,
JK=JK,
readout=readout,
n_tasks=2).to(device)
assert model(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([1, 2])
assert model(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([2, 2])
if __name__ == '__main__':
test_mlp_predictor()
test_gcn_predictor()
test_gat_predictor()
test_attentivefp_predictor()
test_schnet_predictor()
test_mgcn_predictor()
test_mpnn_predictor()
test_weave_predictor()
test_gin_predictor()
import dgl
import numpy as np
import torch
from dgl import DGLGraph
from dgllife.model.model_zoo import *
def get_complete_graph(num_nodes):
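    # Note that the double loop below includes i == j, so the "complete graph"
    # also has self-loops and an n-node graph gets n * n directed edges.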
edge_list = []
for i in range(num_nodes):
for j in range(num_nodes):
edge_list.append((i, j))
return DGLGraph(edge_list)
def test_graph1():
"""
Bi-directed graphs and complete graphs for the molecules.
In addition to node features/edge features, we also return
features for the pairs of nodes.
"""
mol_graph = DGLGraph([(0, 1), (0, 2), (1, 2)])
node_feats = torch.arange(mol_graph.number_of_nodes()).float().reshape(-1, 1)
edge_feats = torch.arange(2 * mol_graph.number_of_edges()).float().reshape(-1, 2)
complete_graph = get_complete_graph(mol_graph.number_of_nodes())
atom_pair_feats = torch.arange(complete_graph.number_of_edges()).float().reshape(-1, 1)
return mol_graph, node_feats, edge_feats, complete_graph, atom_pair_feats
def test_graph2():
"""Batched version of test_graph1"""
mol_graph1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
mol_graph2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
batch_mol_graph = dgl.batch([mol_graph1, mol_graph2])
node_feats = torch.arange(batch_mol_graph.number_of_nodes()).float().reshape(-1, 1)
edge_feats = torch.arange(2 * batch_mol_graph.number_of_edges()).float().reshape(-1, 2)
complete_graph1 = get_complete_graph(mol_graph1.number_of_nodes())
complete_graph2 = get_complete_graph(mol_graph2.number_of_nodes())
batch_complete_graph = dgl.batch([complete_graph1, complete_graph2])
atom_pair_feats = torch.arange(batch_complete_graph.number_of_edges()).float().reshape(-1, 1)
return batch_mol_graph, node_feats, edge_feats, batch_complete_graph, atom_pair_feats
def test_wln_reaction_center():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
mol_graph, node_feats, edge_feats, complete_graph, atom_pair_feats = test_graph1()
mol_graph = mol_graph.to(device)
node_feats, edge_feats = node_feats.to(device), edge_feats.to(device)
complete_graph = complete_graph.to(device)
atom_pair_feats = atom_pair_feats.to(device)
batch_mol_graph, batch_node_feats, batch_edge_feats, batch_complete_graph, \
batch_atom_pair_feats = test_graph2()
batch_mol_graph = batch_mol_graph.to(device)
batch_node_feats, batch_edge_feats = batch_node_feats.to(device), batch_edge_feats.to(device)
batch_complete_graph = batch_complete_graph.to(device)
batch_atom_pair_feats = batch_atom_pair_feats.to(device)
# Test default setting
model = WLNReactionCenter(node_in_feats=1,
edge_in_feats=2,
node_pair_in_feats=1).to(device)
assert model(mol_graph, complete_graph, node_feats, edge_feats, atom_pair_feats)[0].shape == \
torch.Size([complete_graph.number_of_edges(), 5])
assert model(batch_mol_graph, batch_complete_graph, batch_node_feats,
batch_edge_feats, batch_atom_pair_feats)[0].shape == \
torch.Size([batch_complete_graph.number_of_edges(), 5])
# Test configured setting
model = WLNReactionCenter(node_in_feats=1,
edge_in_feats=2,
node_pair_in_feats=1,
node_out_feats=1,
n_layers=1,
n_tasks=1).to(device)
assert model(mol_graph, complete_graph, node_feats, edge_feats, atom_pair_feats)[0].shape == \
torch.Size([complete_graph.number_of_edges(), 1])
assert model(batch_mol_graph, batch_complete_graph, batch_node_feats,
batch_edge_feats, batch_atom_pair_feats)[0].shape == \
torch.Size([batch_complete_graph.number_of_edges(), 1])
# A helper rather than a test case: pytest cannot collect functions that
# require positional arguments.
def make_reactant_product_graphs(batch_size, device):
edges = (np.array([0, 1, 2]), np.array([1, 2, 2]))
reactant_g = []
for _ in range(batch_size):
reactant_g.append(DGLGraph(edges))
reactant_g = dgl.batch(reactant_g)
reactant_node_feats = torch.arange(
reactant_g.number_of_nodes()).float().reshape(-1, 1).to(device)
reactant_edge_feats = torch.arange(
reactant_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_g = []
batch_num_candidate_products = []
for i in range(1, batch_size + 1):
product_g.extend([
DGLGraph(edges) for _ in range(i)
])
batch_num_candidate_products.append(i)
product_g = dgl.batch(product_g)
product_node_feats = torch.arange(
product_g.number_of_nodes()).float().reshape(-1, 1).to(device)
product_edge_feats = torch.arange(
product_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_scores = torch.randn(sum(batch_num_candidate_products), 1).to(device)
return reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, batch_num_candidate_products
def test_wln_candidate_ranking():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, num_candidate_products = \
        make_reactant_product_graphs(batch_size=1, device=device)
batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats, batch_product_g, \
batch_product_node_feats, batch_product_edge_feats, batch_product_scores, \
        batch_num_candidate_products = make_reactant_product_graphs(batch_size=2, device=device)
# Test default setting
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1,
node_hidden_feats=100,
num_encode_gnn_layers=2).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
if __name__ == '__main__':
test_wln_reaction_center()
test_wln_candidate_ranking()
import dgl
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.readout import *
def test_graph1():
"""Graph with node features"""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"Batched graph with node features"
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_weighted_sum_and_max():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = WeightedSumAndMax(in_feats=1).to(device)
assert model(g, node_feats).shape == torch.Size([1, 2])
assert model(bg, batch_node_feats).shape == torch.Size([2, 2])
def test_attentive_fp_readout():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = AttentiveFPReadout(feat_size=1,
num_timesteps=1).to(device)
assert model(g, node_feats).shape == torch.Size([1, 1])
assert model(bg, batch_node_feats).shape == torch.Size([2, 1])
def test_mlp_readout():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
activation=F.relu,
mode='sum').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
mode='max').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
mode='mean').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
def test_weave_readout():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = WeaveGather(node_in_feats=1).to(device)
assert model(g, node_feats).shape == torch.Size([1, 1])
assert model(bg, batch_node_feats).shape == torch.Size([2, 1])
model = WeaveGather(node_in_feats=1, gaussian_expand=False).to(device)
assert model(g, node_feats).shape == torch.Size([1, 1])
assert model(bg, batch_node_feats).shape == torch.Size([2, 1])
if __name__ == '__main__':
test_weighted_sum_and_max()
test_attentive_fp_readout()
test_mlp_readout()
test_weave_readout()
#!/bin/bash
# Argument
# - dev: cpu or gpu
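# Example invocation (the script path is illustrative):
#   bash tests/scripts/build_dgllife.sh cpu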
if [ $# -ne 1 ]; then
echo "Device argument required, can be cpu or gpu"
    exit 1
fi
dev=$1
set -e
. /opt/conda/etc/profile.d/conda.sh
rm -rf _deps
mkdir _deps
pushd _deps
conda activate "pytorch-ci"
if [ "$dev" == "gpu" ]; then
pip uninstall -y dgl
pip install --pre dgl
python3 setup.py install
else
pip uninstall -y dgl-cu101
pip install --pre dgl-cu101
python3 setup.py install
fi
popd
\ No newline at end of file
#!/bin/bash
# Adapted from github.com/dmlc/dgl/tests/scripts/task_lint.sh
# pylint
echo 'Checking code style of Python code...'
python3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgllife || exit 1
\ No newline at end of file
#!/bin/bash
. /opt/conda/etc/profile.d/conda.sh
function fail {
    echo "FAIL: $@"
    exit 1
}
function usage {
echo "Usage: $0 backend device"
}
if [ $# -ne 2 ]; then
usage
fail "Error: must specify backend and device"
fi
export DGLBACKEND=$1
export PYTHONPATH=${PWD}/python:$PYTHONPATH
export DGL_DOWNLOAD_DIR=${PWD}
if [ $2 == "gpu" ]
then
export CUDA_VISIBLE_DEVICES=0
else
export CUDA_VISIBLE_DEVICES=-1
fi
conda activate ${DGLBACKEND}-ci
pip install _deps/dgl*.whl
python3 -m pytest -v --junitxml=pytest_data.xml tests/data || fail "data"
python3 -m pytest -v --junitxml=pytest_model.xml tests/model || fail "model"
python3 -m pytest -v --junitxml=pytest_utils.xml tests/utils || fail "utils"
\ No newline at end of file
import os
import shutil
import torch
from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.complex_to_graph import *
from dgllife.utils.rdkit_utils import load_molecule
def remove_dir(dirname):
    if os.path.isdir(dirname):
        try:
            shutil.rmtree(dirname)
        except OSError:
            pass
def test_acnn_graph_construction_and_featurization():
remove_dir('tmp1')
remove_dir('tmp2')
url = _get_dgl_url('dgllife/example_mols.tar.gz')
local_path = 'tmp1/example_mols.tar.gz'
download(url, path=local_path)
extract_archive(local_path, 'tmp2')
pocket_mol, pocket_coords = load_molecule(
'tmp2/example_mols/example.pdb', remove_hs=True)
ligand_mol, ligand_coords = load_molecule(
'tmp2/example_mols/example.pdbqt', remove_hs=True)
pocket_mol_with_h, pocket_coords_with_h = load_molecule(
'tmp2/example_mols/example.pdb', remove_hs=False)
remove_dir('tmp1')
remove_dir('tmp2')
# Test default case
g = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords)
assert set(g.ntypes) == set(['protein_atom', 'ligand_atom'])
    # all four complex relations share the edge type name 'complex'
    assert set(g.etypes) == set(['protein', 'ligand', 'complex'])
assert g.number_of_nodes('protein_atom') == 286
assert g.number_of_nodes('ligand_atom') == 21
assert g.number_of_edges('protein') == 3432
assert g.number_of_edges('ligand') == 252
assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3349
assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 131
assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 121
assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 83
assert 'atomic_number' in g.nodes['protein_atom'].data
assert 'atomic_number' in g.nodes['ligand_atom'].data
assert torch.allclose(g.nodes['protein_atom'].data['mask'],
torch.ones(g.number_of_nodes('protein_atom'), 1))
assert torch.allclose(g.nodes['ligand_atom'].data['mask'],
torch.ones(g.number_of_nodes('ligand_atom'), 1))
assert 'distance' in g.edges['protein'].data
assert 'distance' in g.edges['ligand'].data
assert 'distance' in g.edges[('protein_atom', 'complex', 'protein_atom')].data
assert 'distance' in g.edges[('ligand_atom', 'complex', 'ligand_atom')].data
assert 'distance' in g.edges[('protein_atom', 'complex', 'ligand_atom')].data
assert 'distance' in g.edges[('ligand_atom', 'complex', 'protein_atom')].data
# Test max_num_ligand_atoms and max_num_protein_atoms
max_num_ligand_atoms = 30
max_num_protein_atoms = 300
g = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords,
max_num_ligand_atoms=max_num_ligand_atoms,
max_num_protein_atoms=max_num_protein_atoms)
assert g.number_of_nodes('ligand_atom') == max_num_ligand_atoms
assert g.number_of_nodes('protein_atom') == max_num_protein_atoms
ligand_mask = torch.zeros(max_num_ligand_atoms, 1)
ligand_mask[:ligand_mol.GetNumAtoms(), :] = 1.
assert torch.allclose(ligand_mask, g.nodes['ligand_atom'].data['mask'])
protein_mask = torch.zeros(max_num_protein_atoms, 1)
protein_mask[:pocket_mol.GetNumAtoms(), :] = 1.
assert torch.allclose(protein_mask, g.nodes['protein_atom'].data['mask'])
# Test neighbor_cutoff
neighbor_cutoff = 6.
g = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords,
neighbor_cutoff=neighbor_cutoff)
assert g.number_of_edges('protein') == 3405
assert g.number_of_edges('ligand') == 193
assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3331
assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 123
assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 119
assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 82
# Test max_num_neighbors
g = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords,
max_num_neighbors=6)
assert g.number_of_edges('protein') == 1716
assert g.number_of_edges('ligand') == 126
assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 1691
assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 86
assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 40
assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 25
# Test strip_hydrogens
g = ACNN_graph_construction_and_featurization(pocket_mol_with_h,
pocket_mol_with_h,
pocket_coords_with_h,
pocket_coords_with_h,
strip_hydrogens=True)
assert g.number_of_nodes('ligand_atom') != pocket_mol_with_h.GetNumAtoms()
assert g.number_of_nodes('protein_atom') != pocket_mol_with_h.GetNumAtoms()
non_h_atomic_numbers = []
for i in range(pocket_mol_with_h.GetNumAtoms()):
atom = pocket_mol_with_h.GetAtomWithIdx(i)
if atom.GetSymbol() != 'H':
non_h_atomic_numbers.append(atom.GetAtomicNum())
non_h_atomic_numbers = torch.tensor(non_h_atomic_numbers).float().reshape(-1, 1)
assert torch.allclose(non_h_atomic_numbers, g.nodes['ligand_atom'].data['atomic_number'])
assert torch.allclose(non_h_atomic_numbers, g.nodes['protein_atom'].data['atomic_number'])
if __name__ == '__main__':
test_acnn_graph_construction_and_featurization()
import os
import torch
import torch.nn as nn
from dgllife.utils import EarlyStopping
def remove_file(fname):
if os.path.isfile(fname):
try:
os.remove(fname)
except OSError:
pass
def test_early_stopping_high():
model1 = nn.Linear(2, 3)
stopper = EarlyStopping(mode='higher',
patience=1,
filename='test.pkl')
# Save model in the first step
stopper.step(1., model1)
model1.weight.data = model1.weight.data + 1
model2 = nn.Linear(2, 3)
stopper.load_checkpoint(model2)
assert not torch.allclose(model1.weight, model2.weight)
# Save model checkpoint with performance improvement
model1.weight.data = model1.weight.data + 1
stopper.step(2., model1)
stopper.load_checkpoint(model2)
assert torch.allclose(model1.weight, model2.weight)
# Stop when no improvement observed
model1.weight.data = model1.weight.data + 1
assert stopper.step(0.5, model1)
stopper.load_checkpoint(model2)
assert not torch.allclose(model1.weight, model2.weight)
remove_file('test.pkl')
def test_early_stopping_low():
model1 = nn.Linear(2, 3)
stopper = EarlyStopping(mode='lower',
patience=1,
filename='test.pkl')
# Save model in the first step
stopper.step(1., model1)
model1.weight.data = model1.weight.data + 1
model2 = nn.Linear(2, 3)
stopper.load_checkpoint(model2)
assert not torch.allclose(model1.weight, model2.weight)
# Save model checkpoint with performance improvement
model1.weight.data = model1.weight.data + 1
stopper.step(0.5, model1)
stopper.load_checkpoint(model2)
assert torch.allclose(model1.weight, model2.weight)
# Stop when no improvement observed
model1.weight.data = model1.weight.data + 1
assert stopper.step(2, model1)
stopper.load_checkpoint(model2)
assert not torch.allclose(model1.weight, model2.weight)
remove_file('test.pkl')
if __name__ == '__main__':
test_early_stopping_high()
test_early_stopping_low()
import numpy as np
import torch
from dgllife.utils.eval import *
def test_Meter():
label = torch.tensor([[0., 1.],
[0., 1.],
[1., 0.]])
pred = torch.tensor([[0.5, 0.5],
[0., 1.],
[1., 0.]])
mask = torch.tensor([[1., 0.],
[0., 1.],
[1., 1.]])
label_mean, label_std = label.mean(dim=0), label.std(dim=0)
# pearson r2
meter = Meter(label_mean, label_std)
meter.update(pred, label)
true_scores = [0.7500000774286983, 0.7500000516191412]
assert meter.pearson_r2() == true_scores
assert meter.pearson_r2('mean') == np.mean(true_scores)
assert meter.pearson_r2('sum') == np.sum(true_scores)
assert meter.compute_metric('r2') == true_scores
assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)
meter = Meter(label_mean, label_std)
meter.update(pred, label, mask)
true_scores = [1.0, 1.0]
assert meter.pearson_r2() == true_scores
assert meter.pearson_r2('mean') == np.mean(true_scores)
assert meter.pearson_r2('sum') == np.sum(true_scores)
assert meter.compute_metric('r2') == true_scores
assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)
# mae
meter = Meter()
meter.update(pred, label)
true_scores = [0.1666666716337204, 0.1666666716337204]
assert meter.mae() == true_scores
assert meter.mae('mean') == np.mean(true_scores)
assert meter.mae('sum') == np.sum(true_scores)
assert meter.compute_metric('mae') == true_scores
assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)
meter = Meter()
meter.update(pred, label, mask)
true_scores = [0.25, 0.0]
assert meter.mae() == true_scores
assert meter.mae('mean') == np.mean(true_scores)
assert meter.mae('sum') == np.sum(true_scores)
assert meter.compute_metric('mae') == true_scores
assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)
    # rmse
meter = Meter(label_mean, label_std)
meter.update(pred, label)
true_scores = [0.41068359261794546, 0.4106836107598449]
assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))
meter = Meter(label_mean, label_std)
meter.update(pred, label, mask)
true_scores = [0.44433766459035057, 0.5019903799993205]
assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))
# roc auc score
meter = Meter()
meter.update(pred, label)
true_scores = [1.0, 1.0]
assert meter.roc_auc_score() == true_scores
assert meter.roc_auc_score('mean') == np.mean(true_scores)
assert meter.roc_auc_score('sum') == np.sum(true_scores)
assert meter.compute_metric('roc_auc_score') == true_scores
assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)
meter = Meter()
meter.update(pred, label, mask)
true_scores = [1.0, 1.0]
assert meter.roc_auc_score() == true_scores
assert meter.roc_auc_score('mean') == np.mean(true_scores)
assert meter.roc_auc_score('sum') == np.sum(true_scores)
assert meter.compute_metric('roc_auc_score') == true_scores
assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)
def test_cases_with_undefined_scores():
label = torch.tensor([[0., 1.],
[0., 1.],
[1., 1.]])
pred = torch.tensor([[0.5, 0.5],
[0., 1.],
[1., 0.]])
meter = Meter()
meter.update(pred, label)
true_scores = [1.0]
assert meter.roc_auc_score() == true_scores
assert meter.roc_auc_score('mean') == np.mean(true_scores)
assert meter.roc_auc_score('sum') == np.sum(true_scores)
if __name__ == '__main__':
test_Meter()
test_cases_with_undefined_scores()
import torch
from dgllife.utils.featurizers import *
from rdkit import Chem
def test_one_hot_encoding():
x = 1.
allowable_set = [0., 1., 2.]
assert one_hot_encoding(x, allowable_set) == [0, 1, 0]
assert one_hot_encoding(x, allowable_set, encode_unknown=True) == [0, 1, 0, 0]
assert one_hot_encoding(-1, allowable_set, encode_unknown=True) == [0, 0, 0, 1]
def test_mol1():
return Chem.MolFromSmiles('CCO')
def test_mol2():
return Chem.MolFromSmiles('C1=CC2=CC=CC=CC2=C1')
def test_mol3():
return Chem.MolFromSmiles('O=C(O)/C=C/C(=O)O')
def test_atom_type_one_hot():
mol = test_mol1()
assert atom_type_one_hot(mol.GetAtomWithIdx(0), ['C', 'O']) == [1, 0]
assert atom_type_one_hot(mol.GetAtomWithIdx(2), ['C', 'O']) == [0, 1]
def test_atomic_number_one_hot():
mol = test_mol1()
assert atomic_number_one_hot(mol.GetAtomWithIdx(0), [6, 8]) == [1, 0]
assert atomic_number_one_hot(mol.GetAtomWithIdx(2), [6, 8]) == [0, 1]
def test_atomic_number():
mol = test_mol1()
assert atomic_number(mol.GetAtomWithIdx(0)) == [6]
assert atomic_number(mol.GetAtomWithIdx(2)) == [8]
def test_atom_degree_one_hot():
mol = test_mol1()
assert atom_degree_one_hot(mol.GetAtomWithIdx(0), [0, 1, 2]) == [0, 1, 0]
assert atom_degree_one_hot(mol.GetAtomWithIdx(1), [0, 1, 2]) == [0, 0, 1]
def test_atom_degree():
mol = test_mol1()
assert atom_degree(mol.GetAtomWithIdx(0)) == [1]
assert atom_degree(mol.GetAtomWithIdx(1)) == [2]
def test_atom_total_degree_one_hot():
mol = test_mol1()
assert atom_total_degree_one_hot(mol.GetAtomWithIdx(0), [0, 2, 4]) == [0, 0, 1]
assert atom_total_degree_one_hot(mol.GetAtomWithIdx(2), [0, 2, 4]) == [0, 1, 0]
def test_atom_total_degree():
mol = test_mol1()
assert atom_total_degree(mol.GetAtomWithIdx(0)) == [4]
assert atom_total_degree(mol.GetAtomWithIdx(2)) == [2]
def test_atom_explicit_valence_one_hot():
    mol = test_mol1()
    assert atom_explicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [1, 0, 0]
    assert atom_explicit_valence_one_hot(mol.GetAtomWithIdx(1), [1, 2, 3]) == [0, 1, 0]
def test_atom_explicit_valence():
mol = test_mol1()
assert atom_explicit_valence(mol.GetAtomWithIdx(0)) == [1]
assert atom_explicit_valence(mol.GetAtomWithIdx(1)) == [2]
def test_atom_implicit_valence_one_hot():
mol = test_mol1()
assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [0, 0, 1]
assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(1), [1, 2, 3]) == [0, 1, 0]
def test_atom_implicit_valence():
mol = test_mol1()
assert atom_implicit_valence(mol.GetAtomWithIdx(0)) == [3]
assert atom_implicit_valence(mol.GetAtomWithIdx(1)) == [2]
def test_atom_hybridization_one_hot():
mol = test_mol1()
assert atom_hybridization_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 1, 0, 0]
def test_atom_total_num_H_one_hot():
mol = test_mol1()
assert atom_total_num_H_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 0, 1, 0]
assert atom_total_num_H_one_hot(mol.GetAtomWithIdx(1)) == [0, 0, 1, 0, 0]
def test_atom_total_num_H():
mol = test_mol1()
assert atom_total_num_H(mol.GetAtomWithIdx(0)) == [3]
assert atom_total_num_H(mol.GetAtomWithIdx(1)) == [2]
def test_atom_formal_charge_one_hot():
mol = test_mol1()
assert atom_formal_charge_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 1, 0, 0]
def test_atom_formal_charge():
mol = test_mol1()
assert atom_formal_charge(mol.GetAtomWithIdx(0)) == [0]
def test_atom_num_radical_electrons_one_hot():
mol = test_mol1()
assert atom_num_radical_electrons_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0, 0]
def test_atom_num_radical_electrons():
mol = test_mol1()
assert atom_num_radical_electrons(mol.GetAtomWithIdx(0)) == [0]
def test_atom_is_aromatic_one_hot():
mol = test_mol1()
assert atom_is_aromatic_one_hot(mol.GetAtomWithIdx(0)) == [1, 0]
mol = test_mol2()
assert atom_is_aromatic_one_hot(mol.GetAtomWithIdx(0)) == [0, 1]
def test_atom_is_aromatic():
mol = test_mol1()
assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [0]
mol = test_mol2()
assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [1]
def test_atom_is_in_ring_one_hot():
mol = test_mol1()
assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [1, 0]
mol = test_mol2()
assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [0, 1]
def test_atom_is_in_ring():
mol = test_mol1()
assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [0]
mol = test_mol2()
assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [1]
def test_atom_chiral_tag_one_hot():
mol = test_mol1()
assert atom_chiral_tag_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0]
def test_atom_mass():
mol = test_mol1()
atom = mol.GetAtomWithIdx(0)
assert atom_mass(atom) == [atom.GetMass() * 0.01]
atom = mol.GetAtomWithIdx(1)
assert atom_mass(atom) == [atom.GetMass() * 0.01]
def test_concat_featurizer():
test_featurizer = ConcatFeaturizer(
[atom_is_aromatic_one_hot, atom_chiral_tag_one_hot]
)
mol = test_mol1()
assert test_featurizer(mol.GetAtomWithIdx(0)) == [1, 0, 1, 0, 0, 0]
mol = test_mol2()
assert test_featurizer(mol.GetAtomWithIdx(0)) == [0, 1, 1, 0, 0, 0]
class TestAtomFeaturizer(BaseAtomFeaturizer):
def __init__(self):
super(TestAtomFeaturizer, self).__init__(
featurizer_funcs={
'h1': ConcatFeaturizer([atom_total_degree_one_hot,
atom_formal_charge_one_hot]),
'h2': ConcatFeaturizer([atom_num_radical_electrons_one_hot])
}
)
def test_base_atom_featurizer():
test_featurizer = TestAtomFeaturizer()
assert test_featurizer.feat_size('h1') == 11
assert test_featurizer.feat_size('h2') == 5
mol = test_mol1()
feats = test_featurizer(mol)
assert torch.allclose(feats['h1'],
torch.tensor([[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]]))
assert torch.allclose(feats['h2'],
torch.tensor([[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.]]))
def test_canonical_atom_featurizer():
test_featurizer = CanonicalAtomFeaturizer()
assert test_featurizer.feat_size() == 74
assert test_featurizer.feat_size('h') == 74
mol = test_mol1()
feats = test_featurizer(mol)
assert list(feats.keys()) == ['h']
assert torch.allclose(feats['h'],
torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0.]]))
def test_weave_atom_featurizer():
featurizer = WeaveAtomFeaturizer()
assert featurizer.feat_size() == 27
mol = test_mol1()
feats = featurizer(mol)
assert list(feats.keys()) == ['h']
assert torch.allclose(feats['h'],
torch.tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, -0.0418, 0.0000, 0.0000, 0.0000,
1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0402, 0.0000, 0.0000, 0.0000,
1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, -0.3967, 0.0000, 0.0000, 0.0000,
1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000]]), rtol=1e-3)
def test_pretrain_atom_featurizer():
featurizer = PretrainAtomFeaturizer()
mol = test_mol1()
feats = featurizer(mol)
assert list(feats.keys()) == ['atomic_number', 'chirality_type']
assert torch.allclose(feats['atomic_number'], torch.tensor([[5, 5, 7]]))
assert torch.allclose(feats['chirality_type'], torch.tensor([[0, 0, 0]]))
def test_bond_type_one_hot():
mol = test_mol1()
assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0]
mol = test_mol2()
assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [0, 0, 0, 1]
def test_bond_is_conjugated_one_hot():
mol = test_mol1()
assert bond_is_conjugated_one_hot(mol.GetBondWithIdx(0)) == [1, 0]
mol = test_mol2()
assert bond_is_conjugated_one_hot(mol.GetBondWithIdx(0)) == [0, 1]
def test_bond_is_conjugated():
mol = test_mol1()
assert bond_is_conjugated(mol.GetBondWithIdx(0)) == [0]
mol = test_mol2()
assert bond_is_conjugated(mol.GetBondWithIdx(0)) == [1]
def test_bond_is_in_ring_one_hot():
mol = test_mol1()
assert bond_is_in_ring_one_hot(mol.GetBondWithIdx(0)) == [1, 0]
mol = test_mol2()
assert bond_is_in_ring_one_hot(mol.GetBondWithIdx(0)) == [0, 1]
def test_bond_is_in_ring():
mol = test_mol1()
assert bond_is_in_ring(mol.GetBondWithIdx(0)) == [0]
mol = test_mol2()
assert bond_is_in_ring(mol.GetBondWithIdx(0)) == [1]
def test_bond_stereo_one_hot():
mol = test_mol1()
assert bond_stereo_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0, 0, 0]
def test_bond_direction_one_hot():
mol = test_mol3()
assert bond_direction_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0]
assert bond_direction_one_hot(mol.GetBondWithIdx(2)) == [0, 1, 0]
class TestBondFeaturizer(BaseBondFeaturizer):
def __init__(self):
super(TestBondFeaturizer, self).__init__(
featurizer_funcs={
'h1': ConcatFeaturizer([bond_is_in_ring, bond_is_conjugated]),
'h2': ConcatFeaturizer([bond_stereo_one_hot])
}
)
def test_base_bond_featurizer():
test_featurizer = TestBondFeaturizer()
assert test_featurizer.feat_size('h1') == 2
assert test_featurizer.feat_size('h2') == 6
mol = test_mol1()
feats = test_featurizer(mol)
assert torch.allclose(feats['h1'], torch.tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.]]))
assert torch.allclose(feats['h2'], torch.tensor([[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.]]))
def test_canonical_bond_featurizer():
test_featurizer = CanonicalBondFeaturizer()
assert test_featurizer.feat_size() == 12
assert test_featurizer.feat_size('e') == 12
mol = test_mol1()
feats = test_featurizer(mol)
assert torch.allclose(feats['e'], torch.tensor(
[[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]))
def test_weave_edge_featurizer():
test_featurizer = WeaveEdgeFeaturizer()
assert test_featurizer.feat_size() == 12
mol = test_mol1()
feats = test_featurizer(mol)
assert torch.allclose(feats['e'],
torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))
def test_pretrain_bond_featurizer():
mol = test_mol3()
test_featurizer = PretrainBondFeaturizer()
feats = test_featurizer(mol)
assert torch.allclose(feats['bond_type'].nonzero(),
torch.tensor([[0], [1], [6], [7], [10], [11], [14], [15],
[16], [17], [18], [19], [20], [21]]))
assert torch.allclose(feats['bond_direction_type'].nonzero(),
torch.tensor([[4], [5], [8], [9]]))
test_featurizer = PretrainBondFeaturizer(self_loop=False)
feats = test_featurizer(mol)
assert torch.allclose(feats['bond_type'].nonzero(),
torch.tensor([[0], [1], [6], [7], [10], [11]]))
assert torch.allclose(feats['bond_direction_type'].nonzero(),
torch.tensor([[4], [5], [8], [9]]))
if __name__ == '__main__':
test_one_hot_encoding()
test_atom_type_one_hot()
test_atomic_number_one_hot()
test_atomic_number()
test_atom_degree_one_hot()
test_atom_degree()
test_atom_total_degree_one_hot()
test_atom_total_degree()
    test_atom_explicit_valence_one_hot()
    test_atom_explicit_valence()
test_atom_implicit_valence_one_hot()
test_atom_implicit_valence()
test_atom_hybridization_one_hot()
test_atom_total_num_H_one_hot()
test_atom_total_num_H()
test_atom_formal_charge_one_hot()
test_atom_formal_charge()
test_atom_num_radical_electrons_one_hot()
test_atom_num_radical_electrons()
test_atom_is_aromatic_one_hot()
test_atom_is_aromatic()
test_atom_is_in_ring_one_hot()
test_atom_is_in_ring()
test_atom_chiral_tag_one_hot()
test_atom_mass()
test_concat_featurizer()
test_base_atom_featurizer()
test_canonical_atom_featurizer()
test_weave_atom_featurizer()
test_pretrain_atom_featurizer()
test_bond_type_one_hot()
test_bond_is_conjugated_one_hot()
test_bond_is_conjugated()
test_bond_is_in_ring_one_hot()
test_bond_is_in_ring()
test_bond_stereo_one_hot()
test_bond_direction_one_hot()
test_base_bond_featurizer()
test_canonical_bond_featurizer()
test_weave_edge_featurizer()
test_pretrain_bond_featurizer()
import numpy as np
import torch
from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
from rdkit import Chem
from rdkit.Chem import AllChem
test_smiles1 = 'CCO'
test_smiles2 = 'Fc1ccccc1'
test_smiles3 = '[CH2:1]([CH3:2])[N:3]1[CH2:4][CH2:5][C:6]([CH3:16])' \
'([CH3:17])[c:7]2[cH:8][cH:9][c:10]([N+:13]([O-:14])=[O:15])' \
'[cH:11][c:12]21.[CH3:18][CH2:19][O:20][C:21]([CH3:22])=[O:23]'
class TestAtomFeaturizer(BaseAtomFeaturizer):
def __init__(self):
super(TestAtomFeaturizer, self).__init__(
featurizer_funcs={'hv': ConcatFeaturizer([atomic_number])})
class TestBondFeaturizer(BaseBondFeaturizer):
def __init__(self):
super(TestBondFeaturizer, self).__init__(
featurizer_funcs={'he': ConcatFeaturizer([bond_is_in_ring])})
def test_smiles_to_bigraph():
# Test the case with self loops added.
g1 = smiles_to_bigraph(test_smiles1, add_self_loop=True)
src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
assert torch.allclose(dst, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))
# Test the case without self loops.
test_node_featurizer = TestAtomFeaturizer()
test_edge_featurizer = TestBondFeaturizer()
g2 = smiles_to_bigraph(test_smiles2, add_self_loop=False,
node_featurizer=test_node_featurizer,
edge_featurizer=test_edge_featurizer)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[9.], [6.], [6.], [6.],
[6.], [6.], [6.]]))
assert torch.allclose(g2.edata['he'], torch.tensor([[0.], [0.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
g3 = smiles_to_bigraph(test_smiles3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g3.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_mol_to_bigraph():
mol1 = Chem.MolFromSmiles(test_smiles1)
g1 = mol_to_bigraph(mol1, add_self_loop=True)
src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
assert torch.allclose(dst, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))
# Test the case without self loops.
mol2 = Chem.MolFromSmiles(test_smiles2)
test_node_featurizer = TestAtomFeaturizer()
test_edge_featurizer = TestBondFeaturizer()
g2 = mol_to_bigraph(mol2, add_self_loop=False,
node_featurizer=test_node_featurizer,
edge_featurizer=test_edge_featurizer)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[9.], [6.], [6.], [6.],
[6.], [6.], [6.]]))
assert torch.allclose(g2.edata['he'], torch.tensor([[0.], [0.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
mol3 = Chem.MolFromSmiles(test_smiles3)
g3 = mol_to_bigraph(mol3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g3.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_smiles_to_complete_graph():
test_node_featurizer = TestAtomFeaturizer()
g1 = smiles_to_complete_graph(test_smiles1, add_self_loop=False,
node_featurizer=test_node_featurizer)
src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
g2 = smiles_to_complete_graph(test_smiles3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_mol_to_complete_graph():
test_node_featurizer = TestAtomFeaturizer()
mol1 = Chem.MolFromSmiles(test_smiles1)
g1 = mol_to_complete_graph(mol1, add_self_loop=False,
node_featurizer=test_node_featurizer)
src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
mol2 = Chem.MolFromSmiles(test_smiles3)
g2 = mol_to_complete_graph(mol2, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_k_nearest_neighbors():
coordinates = np.array([[0.1, 0.1, 0.1],
[0.2, 0.1, 0.1],
[0.15, 0.15, 0.1],
[0.1, 0.15, 0.16],
[1.2, 0.1, 0.1],
[1.3, 0.2, 0.1]])
neighbor_cutoff = 1.
max_num_neighbors = 2
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors)
assert srcs == [2, 3, 2, 0, 0, 1, 0, 2, 1, 5, 4]
assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]
assert dists == [0.07071067811865478, 0.0781024967590666, 0.07071067811865483,
0.1, 0.07071067811865478, 0.07071067811865483, 0.0781024967590666,
0.0781024967590666, 1.0, 0.14142135623730956, 0.14142135623730956]
# Test the case where self loops are included
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff,
max_num_neighbors, self_loops=True)
assert srcs == [0, 2, 1, 2, 2, 0, 3, 0, 4, 5, 4, 5]
assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
assert dists == [0.0, 0.07071067811865478, 0.0, 0.07071067811865483, 0.0,
0.07071067811865478, 0.0, 0.0781024967590666, 0.0,
0.14142135623730956, 0.14142135623730956, 0.0]
# Test the case where max_num_neighbors is not given
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff=10.)
assert srcs == [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5,
0, 1, 2, 4, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 4]
assert dsts == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5]
assert dists == [0.1, 0.07071067811865478, 0.0781024967590666, 1.1,
1.2041594578792296, 0.1, 0.07071067811865483,
0.12688577540449525, 1.0, 1.104536101718726,
0.07071067811865478, 0.07071067811865483,
0.0781024967590666, 1.0511898020814319, 1.151086443322134,
0.0781024967590666, 0.12688577540449525, 0.0781024967590666,
1.1027692415006867, 1.202538980657176, 1.1, 1.0,
1.0511898020814319, 1.1027692415006867, 0.14142135623730956,
1.2041594578792296, 1.104536101718726, 1.151086443322134,
1.202538980657176, 0.14142135623730956]
def test_smiles_to_nearest_neighbor_graph():
mol = Chem.MolFromSmiles(test_smiles1)
AllChem.EmbedMolecule(mol)
coordinates = mol.GetConformers()[0].GetPositions()
# Test node featurizer
test_node_featurizer = TestAtomFeaturizer()
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
node_featurizer=test_node_featurizer)
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
assert g.number_of_edges() == 6
assert 'dist' not in g.edata
# Test self loops
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
add_self_loop=True)
assert g.number_of_edges() == 9
# Test max_num_neighbors
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
max_num_neighbors=1, add_self_loop=True)
assert g.number_of_edges() == 3
# Test pairwise distances
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates,
neighbor_cutoff=10, keep_dists=True)
assert 'dist' in g.edata
coordinates = torch.from_numpy(coordinates)
srcs, dsts = g.edges()
dist = torch.norm(
coordinates[srcs] - coordinates[dsts], dim=1, p=2).float().reshape(-1, 1)
assert torch.allclose(dist, g.edata['dist'])
def test_mol_to_nearest_neighbor_graph():
mol = Chem.MolFromSmiles(test_smiles1)
AllChem.EmbedMolecule(mol)
coordinates = mol.GetConformers()[0].GetPositions()
# Test node featurizer
test_node_featurizer = TestAtomFeaturizer()
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
node_featurizer=test_node_featurizer)
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
assert g.number_of_edges() == 6
assert 'dist' not in g.edata
# Test self loops
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10, add_self_loop=True)
assert g.number_of_edges() == 9
# Test max_num_neighbors
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
max_num_neighbors=1, add_self_loop=True)
assert g.number_of_edges() == 3
# Test pairwise distances
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10, keep_dists=True)
assert 'dist' in g.edata
coordinates = torch.from_numpy(coordinates)
srcs, dsts = g.edges()
dist = torch.norm(
coordinates[srcs] - coordinates[dsts], dim=1, p=2).float().reshape(-1, 1)
assert torch.allclose(dist, g.edata['dist'])
if __name__ == '__main__':
test_smiles_to_bigraph()
test_mol_to_bigraph()
test_smiles_to_complete_graph()
test_mol_to_complete_graph()
test_k_nearest_neighbors()
test_smiles_to_nearest_neighbor_graph()
test_mol_to_nearest_neighbor_graph()
import numpy as np
import os
import shutil
from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.rdkit_utils import get_mol_3d_coordinates, load_molecule
from rdkit import Chem
from rdkit.Chem import AllChem
def test_get_mol_3D_coordinates():
mol = Chem.MolFromSmiles('CCO')
# Test the case when conformation does not exist
assert get_mol_3d_coordinates(mol) is None
# Test the case when conformation exists
AllChem.EmbedMolecule(mol)
AllChem.MMFFOptimizeMolecule(mol)
coords = get_mol_3d_coordinates(mol)
assert isinstance(coords, np.ndarray)
assert coords.shape == (mol.GetNumAtoms(), 3)
def remove_dir(dir_path):
    if os.path.isdir(dir_path):
        try:
            shutil.rmtree(dir_path)
        except OSError:
            pass
def test_load_molecule():
remove_dir('tmp1')
remove_dir('tmp2')
url = _get_dgl_url('dgllife/example_mols.tar.gz')
local_path = 'tmp1/example_mols.tar.gz'
download(url, path=local_path)
extract_archive(local_path, 'tmp2')
load_molecule('tmp2/example_mols/example.sdf')
load_molecule('tmp2/example_mols/example.mol2', use_conformation=False, sanitize=True)
load_molecule('tmp2/example_mols/example.pdbqt', calc_charges=True)
mol, _ = load_molecule('tmp2/example_mols/example.pdb', remove_hs=True)
assert mol.GetNumAtoms() == mol.GetNumHeavyAtoms()
remove_dir('tmp1')
remove_dir('tmp2')
if __name__ == '__main__':
test_get_mol_3D_coordinates()
test_load_molecule()
import torch
from dgllife.utils.splitters import *
from rdkit import Chem
class TestDataset(object):
def __init__(self):
self.smiles = [
'CCO',
'C1CCCCC1',
'O1CCOCC1',
'C1CCCC2C1CCCC2',
'N#N'
]
self.mols = [Chem.MolFromSmiles(s) for s in self.smiles]
self.labels = torch.arange(2 * len(self.smiles)).reshape(len(self.smiles), -1)
def __getitem__(self, item):
return self.smiles[item], self.mols[item]
def __len__(self):
return len(self.smiles)
def test_consecutive_splitter(dataset):
ConsecutiveSplitter.train_val_test_split(dataset)
ConsecutiveSplitter.k_fold_split(dataset)
def test_random_splitter(dataset):
RandomSplitter.train_val_test_split(dataset, random_state=0)
RandomSplitter.k_fold_split(dataset)
def test_molecular_weight_splitter(dataset):
MolecularWeightSplitter.train_val_test_split(dataset)
MolecularWeightSplitter.k_fold_split(dataset, mols=dataset.mols)
def test_scaffold_splitter(dataset):
ScaffoldSplitter.train_val_test_split(dataset, include_chirality=True)
ScaffoldSplitter.k_fold_split(dataset, mols=dataset.mols)
def test_single_task_stratified_splitter(dataset):
SingleTaskStratifiedSplitter.train_val_test_split(dataset, dataset.labels, 1)
SingleTaskStratifiedSplitter.k_fold_split(dataset, dataset.labels, 1)
if __name__ == '__main__':
dataset = TestDataset()
test_consecutive_splitter(dataset)
test_random_splitter(dataset)
test_molecular_weight_splitter(dataset)
test_scaffold_splitter(dataset)
test_single_task_stratified_splitter(dataset)
@@ -129,160 +129,3 @@ Protein-Protein Interaction dataset
.. autoclass:: PPIDataset
:members: __getitem__, __len__
Molecular Graphs
----------------
To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.
Data Loading and Processing Utils
`````````````````````````````````
We adapt several utilities for processing molecules from
`DeepChem <https://github.com/deepchem/deepchem/blob/master/deepchem>`__.
.. autosummary::
:toctree: ../../generated/
chem.add_hydrogens_to_mol
chem.get_mol_3D_coordinates
chem.load_molecule
chem.multiprocess_load_molecules
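For instance, a minimal sketch of loading a molecule together with its
conformation (the file path below is a placeholder):

.. code:: python

    from dgl.data.chem import load_molecule

    # Returns an RDKit molecule object and, when use_conformation is True,
    # an (N, 3) array of atom coordinates (None if unavailable)
    mol, coords = load_molecule('example.sdf', sanitize=True, remove_hs=True)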
Featurization Utils for Single Molecule
```````````````````````````````````````
For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds).
General utils:
.. autosummary::
:toctree: ../../generated/
chem.one_hot_encoding
chem.ConcatFeaturizer
chem.ConcatFeaturizer.__call__
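A small sketch of how these compose, with the expected outputs taken from
the unit tests:

.. code:: python

    from dgl.data.chem import (ConcatFeaturizer, atom_chiral_tag_one_hot,
                               atom_is_aromatic_one_hot, one_hot_encoding)

    # One-hot encode a value against an allowable set
    assert one_hot_encoding(1., [0., 1., 2.]) == [0, 1, 0]

    # Chain several featurizer functions; calling the result on an atom
    # concatenates their outputs into a single list
    featurizer = ConcatFeaturizer([atom_is_aromatic_one_hot,
                                   atom_chiral_tag_one_hot])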
Utils for atom featurization:
.. autosummary::
:toctree: ../../generated/
chem.atom_type_one_hot
chem.atomic_number_one_hot
chem.atomic_number
chem.atom_degree_one_hot
chem.atom_degree
chem.atom_total_degree_one_hot
chem.atom_total_degree
chem.atom_implicit_valence_one_hot
chem.atom_implicit_valence
chem.atom_hybridization_one_hot
chem.atom_total_num_H_one_hot
chem.atom_total_num_H
chem.atom_formal_charge_one_hot
chem.atom_formal_charge
chem.atom_num_radical_electrons_one_hot
chem.atom_num_radical_electrons
chem.atom_is_aromatic_one_hot
chem.atom_is_aromatic
chem.atom_chiral_tag_one_hot
chem.atom_mass
chem.BaseAtomFeaturizer
chem.BaseAtomFeaturizer.feat_size
chem.BaseAtomFeaturizer.__call__
chem.CanonicalAtomFeaturizer
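A sketch of atom featurization with the canonical featurizer, whose feature
key and size follow the unit tests:

.. code:: python

    from rdkit import Chem
    from dgl.data.chem import CanonicalAtomFeaturizer

    featurizer = CanonicalAtomFeaturizer()
    mol = Chem.MolFromSmiles('CCO')
    feats = featurizer(mol)  # dict mapping feature names to tensors
    assert featurizer.feat_size('h') == 74
    assert feats['h'].shape == (3, 74)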
Utils for bond featurization:
.. autosummary::
:toctree: ../../generated/
chem.bond_type_one_hot
chem.bond_is_conjugated_one_hot
chem.bond_is_conjugated
chem.bond_is_in_ring_one_hot
chem.bond_is_in_ring
chem.bond_stereo_one_hot
chem.BaseBondFeaturizer
chem.BaseBondFeaturizer.feat_size
chem.BaseBondFeaturizer.__call__
chem.CanonicalBondFeaturizer
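Bond featurization follows the same pattern; for the canonical bond
featurizer the default feature key is ``'e'``:

.. code:: python

    from rdkit import Chem
    from dgl.data.chem import CanonicalBondFeaturizer

    featurizer = CanonicalBondFeaturizer()
    mol = Chem.MolFromSmiles('CCO')
    feats = featurizer(mol)
    assert featurizer.feat_size('e') == 12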
Graph Construction for Single Molecule
``````````````````````````````````````
Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects are listed below:
.. autosummary::
:toctree: ../../generated/
chem.mol_to_graph
chem.smiles_to_bigraph
chem.mol_to_bigraph
chem.smiles_to_complete_graph
chem.mol_to_complete_graph
chem.k_nearest_neighbors
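A minimal sketch combining graph construction with node featurization:

.. code:: python

    from dgl.data.chem import CanonicalAtomFeaturizer, smiles_to_bigraph

    # One node per atom and a pair of directed edges per bond,
    # with atom features stored in g.ndata['h']
    g = smiles_to_bigraph('CCO', node_featurizer=CanonicalAtomFeaturizer())
    assert g.number_of_nodes() == 3 and g.number_of_edges() == 4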
Graph Construction and Featurization for Ligand-Protein Complex
```````````````````````````````````````````````````````````````
Utilities for constructing DGLHeteroGraphs and featurizing them.
.. autosummary::
:toctree: ../../generated/
chem.ACNN_graph_construction_and_featurization
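A sketch under the assumption that the ligand and protein molecules and
their coordinates have already been loaded, e.g. with ``chem.load_molecule``;
the cutoff values below are illustrative:

.. code:: python

    from dgl.data.chem import ACNN_graph_construction_and_featurization

    # Builds a heterograph with 'ligand_atom'/'protein_atom' node types
    # and 'ligand', 'protein' and 'complex' edge types
    g = ACNN_graph_construction_and_featurization(
        ligand_mol, protein_mol, ligand_coords, protein_coords,
        neighbor_cutoff=6., max_num_neighbors=6)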
Dataset Classes
```````````````
If your dataset is stored in a ``.csv`` file, you may find it helpful to use
.. autoclass:: dgl.data.chem.MoleculeCSVDataset
:members: __getitem__, __len__
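A sketch of constructing such a dataset, assuming a pandas DataFrame with a
``smiles`` column and one column per task (keyword names follow the class
documentation and may need adjusting):

.. code:: python

    import pandas as pd
    from dgl.data.chem import (CanonicalAtomFeaturizer, MoleculeCSVDataset,
                               smiles_to_bigraph)

    df = pd.read_csv('my_data.csv')  # placeholder path
    dataset = MoleculeCSVDataset(df, smiles_to_graph=smiles_to_bigraph,
                                 node_featurizer=CanonicalAtomFeaturizer())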
Currently four datasets are supported:
* Tox21
* TencentAlchemyDataset
* PubChemBioAssayAromaticity
* PDBBind
.. autoclass:: dgl.data.chem.Tox21
:members: __getitem__, __len__, task_pos_weights
.. autoclass:: dgl.data.chem.TencentAlchemyDataset
:members: __getitem__, __len__, set_mean_and_std
.. autoclass:: dgl.data.chem.PubChemBioAssayAromaticity
:members: __getitem__, __len__
.. autoclass:: dgl.data.chem.PDBBind
:members: __getitem__, __len__
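Loading a built-in dataset is then a one-liner; a sketch with Tox21,
assuming the default construction arguments suffice (the data is downloaded
on first use):

.. code:: python

    from dgl.data.chem import Tox21

    dataset = Tox21()
    print(len(dataset))  # number of molecules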
Dataset Splitting
`````````````````
We provide support for some common data splitting methods:
* consecutive split
* random split
* molecular weight split
* Bemis-Murcko scaffold split
* single-task-stratified split
.. autoclass:: dgl.data.chem.ConsecutiveSplitter
:members: train_val_test_split, k_fold_split
.. autoclass:: dgl.data.chem.RandomSplitter
:members: train_val_test_split, k_fold_split
.. autoclass:: dgl.data.chem.MolecularWeightSplitter
:members: train_val_test_split, k_fold_split
.. autoclass:: dgl.data.chem.ScaffoldSplitter
:members: train_val_test_split, k_fold_split
.. autoclass:: dgl.data.chem.SingleTaskStratifiedSplitter
:members: train_val_test_split, k_fold_split
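The splitters share a similar interface; a sketch with the scaffold split,
assuming ``dataset`` holds one RDKit molecule per item as in the splitter
unit tests:

.. code:: python

    from dgl.data.chem import ScaffoldSplitter

    # Partition the dataset by Bemis-Murcko scaffolds
    train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(dataset)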
.. _apimodelzoo:
dgl.model_zoo
==============
.. currentmodule:: dgl.model_zoo
Chemistry
---------
Utils
`````
.. autosummary::
:toctree: ../../generated/
chem.load_pretrained
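A sketch of loading a pretrained model by name; ``'GCN_Tox21'`` is used as
an example identifier, see the function documentation for the models
actually published:

.. code:: python

    from dgl import model_zoo

    model = model_zoo.chem.load_pretrained('GCN_Tox21')
    model.eval()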
Property Prediction
```````````````````
Currently supported model architectures:
* GCNClassifier
* GATClassifier
* MPNNModel
* SchNet
* MGCNModel
* AttentiveFP
.. autoclass:: dgl.model_zoo.chem.GCNClassifier
:members: forward
.. autoclass:: dgl.model_zoo.chem.GATClassifier
:members: forward
.. autoclass:: dgl.model_zoo.chem.MPNNModel
:members: forward
.. autoclass:: dgl.model_zoo.chem.SchNet
:members: forward
.. autoclass:: dgl.model_zoo.chem.MGCNModel
:members: forward
.. autoclass:: dgl.model_zoo.chem.AttentiveFP
:members: forward
Generative Models
`````````````````
Currently supported model architectures:
* DGMG
* JTNN
.. autoclass:: dgl.model_zoo.chem.DGMG
:members: forward
.. autoclass:: dgl.model_zoo.chem.DGLJTNNVAE
:members: forward
Protein Ligand Binding
``````````````````````
Currently supported model architectures:
* ACNN
.. autoclass:: dgl.model_zoo.chem.ACNN
:members: forward