OpenDAS / dgl · Commit 36c7b771 (unverified)
Authored Jun 05, 2020 by Mufei Li; committed by GitHub on Jun 05, 2020
[LifeSci] Move to Independent Repo (#1592)
* Move LifeSci
* Remove doc
parent 94c67203
Changes: 203
Showing 20 changed files with 0 additions and 3004 deletions
apps/life_sci/tests/lint/pylintrc                        +0 −500
apps/life_sci/tests/model/test_binding_affinity.py       +0 −69
apps/life_sci/tests/model/test_generative_models.py      +0 −43
apps/life_sci/tests/model/test_gnn.py                    +0 −320
apps/life_sci/tests/model/test_pretrain.py               +0 −162
apps/life_sci/tests/model/test_property_prediction.py    +0 −323
apps/life_sci/tests/model/test_reaction_prediction.py    +0 −155
apps/life_sci/tests/model/test_readout.py                +0 −105
apps/life_sci/tests/scripts/build.sh                     +0 −28
apps/life_sci/tests/scripts/task_lint.sh                 +0 −6
apps/life_sci/tests/scripts/task_unit_test.sh            +0 −35
apps/life_sci/tests/utils/test_complex_to_graph.py       +0 −128
apps/life_sci/tests/utils/test_early_stop.py             +0 −70
apps/life_sci/tests/utils/test_eval.py                   +0 −110
apps/life_sci/tests/utils/test_featurizers.py            +0 −385
apps/life_sci/tests/utils/test_mol_to_graph.py           +0 −238
apps/life_sci/tests/utils/test_rdkit_utils.py            +0 −49
apps/life_sci/tests/utils/test_splitters.py              +0 −50
docs/source/api/python/data.rst                          +0 −157
docs/source/api/python/model_zoo.rst                     +0 −71
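The removed suites are plain pytest-style test modules. As a minimal sketch (not part of this commit), they could be run from a checkout that still contains apps/life_sci, assuming the dgllife package is importable:

# Minimal sketch; pytest discovers the test_* functions in the deleted modules.
# The paths come from the file list above; an installed dgllife is an assumption.
python -m pytest apps/life_sci/tests/model apps/life_sci/tests/utils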
apps/life_sci/tests/lint/pylintrc deleted 100644 → 0
[MASTER]
# Adapted from github.com/dmlc/dgl/tests/lint/pylintrc
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,_cy2,_cy3,backend,data,contrib
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=4
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=design,
similarities,
no-self-use,
attribute-defined-outside-init,
locally-disabled,
star-args,
pointless-except,
bad-option-value,
global-statement,
fixme,
suppressed-message,
useless-suppression,
locally-enabled,
import-error,
unsubscriptable-object,
unbalanced-tuple-unpacking,
protected-access,
useless-object-inheritance,
no-else-return,
len-as-condition,
cyclic-import, # disabled due to the inevitable dgl.graph -> dgl.subgraph loop
undefined-variable, # disabled due to C extension (should enable)
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a score less than or equal to 10 (10 is
# the highest score). You have access to the variables error, warning, refactor,
# convention and statement, which respectively contain the number of messages of
# each category and the total number of statements analyzed. This is used by the
# global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings; shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,u,v,e,n,m,w,x,y,g,G,hg,fn,ex,Run,_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=4000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=dgl.backend,dgl._api_internal
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=yes
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception".
overgeneral-exceptions=Exception
\ No newline at end of file
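A minimal sketch of how a configuration like this pylintrc is consumed (the dgllife target package is an assumption based on the code under test, not something this commit specifies):

# Minimal sketch; --rcfile is a standard pylint option. Running pylint as a
# module assumes pylint is installed; the dgllife target is an assumption.
python -m pylint --rcfile=apps/life_sci/tests/lint/pylintrc dgllife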
apps/life_sci/tests/model/test_binding_affinity.py deleted 100644 → 0

import dgl
import os
import shutil
import torch

from dgl.data.utils import _get_dgl_url, download, extract_archive
from dgllife.model.model_zoo.acnn import ACNN
from dgllife.utils.complex_to_graph import ACNN_graph_construction_and_featurization
from dgllife.utils.rdkit_utils import load_molecule

def remove_dir(dir):
    if os.path.isdir(dir):
        try:
            shutil.rmtree(dir)
        except OSError:
            pass

def test_acnn():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')
    pocket_mol, pocket_coords = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)
    remove_dir('tmp1')
    remove_dir('tmp2')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)
    model = ACNN()
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])

    model = ACNN(hidden_sizes=[1, 2],
                 weight_init_stddevs=[1, 1],
                 dropouts=[0.1, 0.],
                 features_to_use=torch.tensor([6., 8.]),
                 radial=[[12.0], [0.0, 2.0], [4.0]])
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])

if __name__ == '__main__':
    test_acnn()
apps/life_sci/tests/model/test_generative_models.py deleted 100644 → 0

import torch

from rdkit import Chem

from dgllife.model import DGMG, DGLJTNNVAE

def test_dgmg():
    model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
                 bond_types=[Chem.rdchem.BondType.SINGLE,
                             Chem.rdchem.BondType.DOUBLE,
                             Chem.rdchem.BondType.TRIPLE],
                 node_hidden_size=1,
                 num_prop_rounds=1,
                 dropout=0.2)
    assert model(actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None

    model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
                 bond_types=[Chem.rdchem.BondType.SINGLE,
                             Chem.rdchem.BondType.DOUBLE,
                             Chem.rdchem.BondType.TRIPLE])
    assert model(actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None

def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = DGLJTNNVAE(hidden_size=1,
                       latent_size=2,
                       depth=1).to(device)

if __name__ == '__main__':
    test_dgmg()
    test_jtnn()
apps/life_sci/tests/model/test_gnn.py deleted 100644 → 0

import dgl
import torch
import torch.nn.functional as F

from dgl import DGLGraph

from dgllife.model.gnn import *

def test_graph1():
    """Graph with node features."""
    g = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)

def test_graph2():
    """Batched graph with node features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)

def test_graph3():
    """Graph with node and edge features."""
    g = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
           torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)

def test_graph4():
    """Batched graph with node and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
           torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)

def test_graph5():
    """Graph with node types and edge distances."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)

def test_graph6():
    """Batched graph with node types and edge distances."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)

def test_graph7():
    """Graph with categorical node and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
           torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])

def test_graph8():
    """Batched graph with categorical node and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
           torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
           torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
           torch.LongTensor([2, 3, 2, 0, 1, 2, 1])

def test_gcn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    # Test default setting
    gnn = GCN(in_feats=1).to(device)
    assert gnn(g, node_feats).shape == torch.Size([3, 64])
    assert gnn(bg, batch_node_feats).shape == torch.Size([8, 64])

    # Test configured setting
    gnn = GCN(in_feats=1,
              hidden_feats=[1, 1],
              activation=[F.relu, F.relu],
              residual=[True, True],
              batchnorm=[True, True],
              dropout=[0.2, 0.2]).to(device)
    assert gnn(g, node_feats).shape == torch.Size([3, 1])
    assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])

def test_gat():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    # Test default setting
    gnn = GAT(in_feats=1).to(device)
    assert gnn(g, node_feats).shape == torch.Size([3, 32])
    assert gnn(bg, batch_node_feats).shape == torch.Size([8, 32])

    # Test configured setting
    gnn = GAT(in_feats=1,
              hidden_feats=[1, 1],
              num_heads=[2, 3],
              feat_drops=[0.1, 0.1],
              attn_drops=[0.1, 0.1],
              alphas=[0.2, 0.2],
              residuals=[True, True],
              agg_modes=['flatten', 'mean'],
              activations=[None, F.elu]).to(device)
    assert gnn(g, node_feats).shape == torch.Size([3, 1])
    assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])

    gnn = GAT(in_feats=1,
              hidden_feats=[1, 1],
              num_heads=[2, 3],
              feat_drops=[0.1, 0.1],
              attn_drops=[0.1, 0.1],
              alphas=[0.2, 0.2],
              residuals=[True, True],
              agg_modes=['mean', 'flatten'],
              activations=[None, F.elu]).to(device)
    assert gnn(g, node_feats).shape == torch.Size([3, 3])
    assert gnn(bg, batch_node_feats).shape == torch.Size([8, 3])

def test_attentive_fp_gnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test AttentiveFPGNN
    gnn = AttentiveFPGNN(node_feat_size=1,
                         edge_feat_size=2,
                         num_layers=1,
                         graph_feat_size=1,
                         dropout=0.).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 1])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 1])

def test_schnet_gnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_types, edge_dists = test_graph5()
    g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
    bg, batch_node_types, batch_edge_dists = test_graph6()
    bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
                                             batch_edge_dists.to(device)

    # Test default setting
    gnn = SchNetGNN().to(device)
    assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 64])
    assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 64])

    # Test configured setting
    gnn = SchNetGNN(num_node_types=5,
                    node_feats=2,
                    hidden_feats=[3],
                    cutoff=0.3).to(device)
    assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 2])
    assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 2])

def test_mgcn_gnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_types, edge_dists = test_graph5()
    g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
    bg, batch_node_types, batch_edge_dists = test_graph6()
    bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
                                             batch_edge_dists.to(device)

    # Test default setting
    gnn = MGCNGNN().to(device)
    assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 512])
    assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 512])

    # Test configured setting
    gnn = MGCNGNN(feats=2,
                  n_layers=2,
                  num_node_types=5,
                  num_edge_types=150,
                  cutoff=0.3).to(device)
    assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 6])
    assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 6])

def test_mpnn_gnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    gnn = MPNNGNN(node_in_feats=1, edge_in_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 64])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 64])

    # Test configured setting
    gnn = MPNNGNN(node_in_feats=1,
                  edge_in_feats=2,
                  node_out_feats=2,
                  edge_hidden_feats=2,
                  num_step_message_passing=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])

def test_wln():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    gnn = WLN(node_in_feats=1, edge_in_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 300])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 300])

    # Test configured setting
    gnn = WLN(node_in_feats=1,
              edge_in_feats=2,
              node_out_feats=3,
              n_layers=1).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 3])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 3])

def test_weave():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    gnn = WeaveGNN(node_in_feats=1, edge_in_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 50])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 50])

    # Test configured setting
    gnn = WeaveGNN(node_in_feats=1,
                   edge_in_feats=2,
                   num_layers=1,
                   hidden_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])

def test_gin():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
    node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
    edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
    bg, batch_node_feats1, batch_node_feats2, \
    batch_edge_feats1, batch_edge_feats2 = test_graph8()
    batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
                                           batch_node_feats2.to(device)
    batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
                                           batch_edge_feats2.to(device)

    # Test default setting
    gnn = GIN(num_node_emb_list=[3, 5],
              num_edge_emb_list=[3, 4]).to(device)
    assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
        == torch.Size([3, 300])
    assert gnn(bg, [batch_node_feats1, batch_node_feats2],
               [batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 300])

    # Test configured setting
    gnn = GIN(num_node_emb_list=[3, 5],
              num_edge_emb_list=[3, 4],
              num_layers=2,
              emb_dim=10,
              JK='concat',
              dropout=0.1).to(device)
    assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
        == torch.Size([3, 30])
    assert gnn(bg, [batch_node_feats1, batch_node_feats2],
               [batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 30])

if __name__ == '__main__':
    test_gcn()
    test_gat()
    test_attentive_fp_gnn()
    test_schnet_gnn()
    test_mgcn_gnn()
    test_mpnn_gnn()
    test_wln()
    test_weave()
    test_gin()
apps/life_sci/tests/model/test_pretrain.py deleted 100644 → 0

import dgl
import os
import torch

from functools import partial

from dgllife.model import load_pretrained
from dgllife.utils import *

def remove_file(fname):
    if os.path.isfile(fname):
        try:
            os.remove(fname)
        except OSError:
            pass

def run_dgmg_ChEMBL(model):
    assert model(actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None

def run_dgmg_ZINC(model):
    assert model(actions=[(0, 2), (1, 3), (0, 5), (1, 0), (2, 0), (1, 3), (0, 9)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None

def test_dgmg():
    model = load_pretrained('DGMG_ZINC_canonical')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ZINC_random')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ChEMBL_canonical')
    run_dgmg_ChEMBL(model)
    model = load_pretrained('DGMG_ChEMBL_random')
    run_dgmg_ChEMBL(model)

    remove_file('DGMG_ChEMBL_canonical_pre_trained.pth')
    remove_file('DGMG_ChEMBL_random_pre_trained.pth')
    remove_file('DGMG_ZINC_canonical_pre_trained.pth')
    remove_file('DGMG_ZINC_random_pre_trained.pth')

def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_pretrained('JTNN_ZINC').to(device)
    remove_file('JTNN_ZINC_pre_trained.pth')

def test_gcn_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = CanonicalAtomFeaturizer()
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('GCN_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device))
    remove_file('GCN_Tox21_pre_trained.pth')

def test_gat_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = CanonicalAtomFeaturizer()
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('GAT_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device))
    remove_file('GAT_Tox21_pre_trained.pth')

def test_weave_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = WeaveAtomFeaturizer()
    edge_featurizer = WeaveEdgeFeaturizer(max_distance=2)
    g1 = smiles_to_complete_graph('CO', node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer, add_self_loop=True)
    g2 = smiles_to_complete_graph('CCO', node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer, add_self_loop=True)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('Weave_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device), bg.edata.pop('e').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device), g1.edata.pop('e').to(device))
    remove_file('Weave_Tox21_pre_trained.pth')

def chirality(atom):
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except:
        return [False, False] + [atom.HasProp('_ChiralityPossible')]

def test_attentivefp_aromaticity():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot,
                    allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S',
                                   'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge,
            atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # A placeholder for aromatic information
            atom_total_num_H_one_hot,
            chirality],
        )}
    )
    edge_featurizer = BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('AttentiveFP_Aromaticity').to(device)
    model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))
    remove_file('AttentiveFP_Aromaticity_pre_trained.pth')

if __name__ == '__main__':
    test_dgmg()
    test_jtnn()
    test_gcn_tox21()
    test_gat_tox21()
    test_attentivefp_aromaticity()
apps/life_sci/tests/model/test_property_prediction.py deleted 100644 → 0

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import DGLGraph

from dgllife.model.model_zoo import *

def test_graph1():
    """Graph with node features."""
    g = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)

def test_graph2():
    """Batched graph with node features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)

def test_graph3():
    """Graph with node features and edge features."""
    g = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
           torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)

def test_graph4():
    """Batched graph with node features and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
           torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)

def test_graph5():
    """Graph with node types and edge distances."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)

def test_graph6():
    """Batched graph with node types and edge distances."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)

def test_graph7():
    """Graph with categorical node and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
           torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])

def test_graph8():
    """Batched graph with categorical node and edge features."""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
           torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
           torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
           torch.LongTensor([2, 3, 2, 0, 1, 2, 1])

def test_mlp_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g_feats = torch.tensor([[1.], [2.]]).to(device)
    mlp_predictor = MLPPredictor(in_feats=1, hidden_feats=1, n_tasks=2).to(device)
    assert mlp_predictor(g_feats).shape == torch.Size([2, 2])

def test_gcn_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    # Test default setting
    gcn_predictor = GCNPredictor(in_feats=1).to(device)
    gcn_predictor.eval()
    assert gcn_predictor(g, node_feats).shape == torch.Size([1, 1])
    gcn_predictor.train()
    assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])

    # Test configured setting
    gcn_predictor = GCNPredictor(in_feats=1,
                                 hidden_feats=[1],
                                 activation=[F.relu],
                                 residual=[True],
                                 batchnorm=[True],
                                 dropout=[0.1],
                                 classifier_hidden_feats=1,
                                 classifier_dropout=0.1,
                                 n_tasks=2).to(device)
    gcn_predictor.eval()
    assert gcn_predictor(g, node_feats).shape == torch.Size([1, 2])
    gcn_predictor.train()
    assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 2])

def test_gat_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    # Test default setting
    gat_predictor = GATPredictor(in_feats=1).to(device)
    gat_predictor.eval()
    assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
    gat_predictor.train()
    assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])

    # Test configured setting
    gat_predictor = GATPredictor(in_feats=1,
                                 hidden_feats=[1, 2],
                                 num_heads=[2, 3],
                                 feat_drops=[0.1, 0.1],
                                 attn_drops=[0.1, 0.1],
                                 alphas=[0.1, 0.1],
                                 residuals=[True, True],
                                 agg_modes=['mean', 'flatten'],
                                 activations=[None, F.elu]).to(device)
    gat_predictor.eval()
    assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
    gat_predictor.train()
    assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])

def test_attentivefp_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    attentivefp_predictor = AttentiveFPPredictor(node_feat_size=1,
                                                 edge_feat_size=2,
                                                 num_layers=2,
                                                 num_timesteps=1,
                                                 graph_feat_size=1,
                                                 n_tasks=2).to(device)
    assert attentivefp_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
    assert attentivefp_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
           torch.Size([2, 2])

def test_schnet_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_types, edge_dists = test_graph5()
    g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
    bg, batch_node_types, batch_edge_dists = test_graph6()
    bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
                                             batch_edge_dists.to(device)

    # Test default setting
    schnet_predictor = SchNetPredictor().to(device)
    assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
    assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
           torch.Size([2, 1])

    # Test configured setting
    schnet_predictor = SchNetPredictor(node_feats=2,
                                       hidden_feats=[2, 2],
                                       classifier_hidden_feats=3,
                                       n_tasks=3,
                                       num_node_types=5,
                                       cutoff=0.3).to(device)
    assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
    assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
           torch.Size([2, 3])

def test_mgcn_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_types, edge_dists = test_graph5()
    g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
    bg, batch_node_types, batch_edge_dists = test_graph6()
    bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
                                             batch_edge_dists.to(device)

    # Test default setting
    mgcn_predictor = MGCNPredictor().to(device)
    assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
    assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
           torch.Size([2, 1])

    # Test configured setting
    mgcn_predictor = MGCNPredictor(feats=2,
                                   n_layers=2,
                                   classifier_hidden_feats=3,
                                   n_tasks=3,
                                   num_node_types=5,
                                   num_edge_types=150,
                                   cutoff=0.3).to(device)
    assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
    assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
           torch.Size([2, 3])

def test_mpnn_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    mpnn_predictor = MPNNPredictor(node_in_feats=1, edge_in_feats=2).to(device)
    assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 1])
    assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
           torch.Size([2, 1])

    # Test configured setting
    mpnn_predictor = MPNNPredictor(node_in_feats=1,
                                   edge_in_feats=2,
                                   node_out_feats=2,
                                   edge_hidden_feats=2,
                                   n_tasks=2,
                                   num_step_message_passing=2,
                                   num_step_set2set=2,
                                   num_layer_set2set=2).to(device)
    assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
    assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
           torch.Size([2, 2])

def test_weave_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    weave_predictor = WeavePredictor(node_in_feats=1, edge_in_feats=2).to(device)
    assert weave_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
           torch.Size([2, 1])

    # Test configured setting
    weave_predictor = WeavePredictor(node_in_feats=1,
                                     edge_in_feats=2,
                                     num_gnn_layers=2,
                                     gnn_hidden_feats=10,
                                     gnn_activation=F.relu,
                                     graph_feats=128,
                                     gaussian_expand=True,
                                     gaussian_memberships=None,
                                     readout_activation=nn.Tanh(),
                                     n_tasks=2).to(device)
    assert weave_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
           torch.Size([2, 2])

def test_gin_predictor():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
    node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
    edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
    bg, batch_node_feats1, batch_node_feats2, \
    batch_edge_feats1, batch_edge_feats2 = test_graph8()
    batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
                                           batch_node_feats2.to(device)
    batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
                                           batch_edge_feats2.to(device)

    num_node_emb_list = [3, 5]
    num_edge_emb_list = [3, 4]
    for JK in ['concat', 'last', 'max', 'sum']:
        for readout in ['sum', 'mean', 'max', 'attention']:
            model = GINPredictor(num_node_emb_list=num_node_emb_list,
                                 num_edge_emb_list=num_edge_emb_list,
                                 num_layers=2,
                                 emb_dim=10,
                                 JK=JK,
                                 readout=readout,
                                 n_tasks=2).to(device)
            assert model(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
                == torch.Size([1, 2])
            assert model(bg, [batch_node_feats1, batch_node_feats2],
                         [batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([2, 2])

if __name__ == '__main__':
    test_mlp_predictor()
    test_gcn_predictor()
    test_gat_predictor()
    test_attentivefp_predictor()
    test_schnet_predictor()
    test_mgcn_predictor()
    test_mpnn_predictor()
    test_weave_predictor()
    test_gin_predictor()
apps/life_sci/tests/model/test_reaction_prediction.py
deleted
100644 → 0
View file @
94c67203
import
dgl
import
numpy
as
np
import
torch
from
dgl
import
DGLGraph
from
dgllife.model.model_zoo
import
*
def
get_complete_graph
(
num_nodes
):
edge_list
=
[]
for
i
in
range
(
num_nodes
):
for
j
in
range
(
num_nodes
):
edge_list
.
append
((
i
,
j
))
return
DGLGraph
(
edge_list
)
def
test_graph1
():
"""
Bi-directed graphs and complete graphs for the molecules.
In addition to node features/edge features, we also return
features for the pairs of nodes.
"""
mol_graph
=
DGLGraph
([(
0
,
1
),
(
0
,
2
),
(
1
,
2
)])
node_feats
=
torch
.
arange
(
mol_graph
.
number_of_nodes
()).
float
().
reshape
(
-
1
,
1
)
edge_feats
=
torch
.
arange
(
2
*
mol_graph
.
number_of_edges
()).
float
().
reshape
(
-
1
,
2
)
complete_graph
=
get_complete_graph
(
mol_graph
.
number_of_nodes
())
atom_pair_feats
=
torch
.
arange
(
complete_graph
.
number_of_edges
()).
float
().
reshape
(
-
1
,
1
)
return
mol_graph
,
node_feats
,
edge_feats
,
complete_graph
,
atom_pair_feats
def
test_graph2
():
"""Batched version of test_graph1"""
mol_graph1
=
DGLGraph
([(
0
,
1
),
(
0
,
2
),
(
1
,
2
)])
mol_graph2
=
DGLGraph
([(
0
,
1
),
(
1
,
2
),
(
1
,
3
),
(
1
,
4
)])
batch_mol_graph
=
dgl
.
batch
([
mol_graph1
,
mol_graph2
])
node_feats
=
torch
.
arange
(
batch_mol_graph
.
number_of_nodes
()).
float
().
reshape
(
-
1
,
1
)
edge_feats
=
torch
.
arange
(
2
*
batch_mol_graph
.
number_of_edges
()).
float
().
reshape
(
-
1
,
2
)
complete_graph1
=
get_complete_graph
(
mol_graph1
.
number_of_nodes
())
complete_graph2
=
get_complete_graph
(
mol_graph2
.
number_of_nodes
())
batch_complete_graph
=
dgl
.
batch
([
complete_graph1
,
complete_graph2
])
atom_pair_feats
=
torch
.
arange
(
batch_complete_graph
.
number_of_edges
()).
float
().
reshape
(
-
1
,
1
)
return
batch_mol_graph
,
node_feats
,
edge_feats
,
batch_complete_graph
,
atom_pair_feats
def
test_wln_reaction_center
():
if
torch
.
cuda
.
is_available
():
device
=
torch
.
device
(
'cuda:0'
)
else
:
device
=
torch
.
device
(
'cpu'
)
mol_graph
,
node_feats
,
edge_feats
,
complete_graph
,
atom_pair_feats
=
test_graph1
()
mol_graph
=
mol_graph
.
to
(
device
)
node_feats
,
edge_feats
=
node_feats
.
to
(
device
),
edge_feats
.
to
(
device
)
complete_graph
=
complete_graph
.
to
(
device
)
atom_pair_feats
=
atom_pair_feats
.
to
(
device
)
batch_mol_graph
,
batch_node_feats
,
batch_edge_feats
,
batch_complete_graph
,
\
batch_atom_pair_feats
=
test_graph2
()
batch_mol_graph
=
batch_mol_graph
.
to
(
device
)
batch_node_feats
,
batch_edge_feats
=
batch_node_feats
.
to
(
device
),
batch_edge_feats
.
to
(
device
)
batch_complete_graph
=
batch_complete_graph
.
to
(
device
)
batch_atom_pair_feats
=
batch_atom_pair_feats
.
to
(
device
)
# Test default setting
model
=
WLNReactionCenter
(
node_in_feats
=
1
,
edge_in_feats
=
2
,
node_pair_in_feats
=
1
).
to
(
device
)
assert
model
(
mol_graph
,
complete_graph
,
node_feats
,
edge_feats
,
atom_pair_feats
)[
0
].
shape
==
\
torch
.
Size
([
complete_graph
.
number_of_edges
(),
5
])
assert
model
(
batch_mol_graph
,
batch_complete_graph
,
batch_node_feats
,
batch_edge_feats
,
batch_atom_pair_feats
)[
0
].
shape
==
\
torch
.
Size
([
batch_complete_graph
.
number_of_edges
(),
5
])
# Test configured setting
model
=
WLNReactionCenter
(
node_in_feats
=
1
,
edge_in_feats
=
2
,
node_pair_in_feats
=
1
,
node_out_feats
=
1
,
n_layers
=
1
,
n_tasks
=
1
).
to
(
device
)
assert
model
(
mol_graph
,
complete_graph
,
node_feats
,
edge_feats
,
atom_pair_feats
)[
0
].
shape
==
\
torch
.
Size
([
complete_graph
.
number_of_edges
(),
1
])
assert
model
(
batch_mol_graph
,
batch_complete_graph
,
batch_node_feats
,
batch_edge_feats
,
batch_atom_pair_feats
)[
0
].
shape
==
\
torch
.
Size
([
batch_complete_graph
.
number_of_edges
(),
1
])
def test_reactant_product_graph(batch_size, device):
    edges = (np.array([0, 1, 2]), np.array([1, 2, 2]))
    reactant_g = []
    for _ in range(batch_size):
        reactant_g.append(DGLGraph(edges))
    reactant_g = dgl.batch(reactant_g)
    reactant_node_feats = torch.arange(
        reactant_g.number_of_nodes()).float().reshape(-1, 1).to(device)
    reactant_edge_feats = torch.arange(
        reactant_g.number_of_edges()).float().reshape(-1, 1).to(device)

    product_g = []
    batch_num_candidate_products = []
    for i in range(1, batch_size + 1):
        product_g.extend([DGLGraph(edges) for _ in range(i)])
        batch_num_candidate_products.append(i)
    product_g = dgl.batch(product_g)
    product_node_feats = torch.arange(
        product_g.number_of_nodes()).float().reshape(-1, 1).to(device)
    product_edge_feats = torch.arange(
        product_g.number_of_edges()).float().reshape(-1, 1).to(device)
    product_scores = torch.randn(sum(batch_num_candidate_products), 1).to(device)

    return reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
           product_edge_feats, product_scores, batch_num_candidate_products

def test_wln_candidate_ranking():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
    product_edge_feats, product_scores, num_candidate_products = \
        test_reactant_product_graph(batch_size=1, device=device)
    batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats, batch_product_g, \
    batch_product_node_feats, batch_product_edge_feats, batch_product_scores, \
    batch_num_candidate_products = test_reactant_product_graph(batch_size=2, device=device)

    # Test default setting
    model = WLNReactionRanking(node_in_feats=1, edge_in_feats=1).to(device)
    assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
                 product_node_feats, product_edge_feats, product_scores,
                 num_candidate_products).shape == \
           torch.Size([sum(num_candidate_products), 1])
    assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
                 batch_product_g, batch_product_node_feats, batch_product_edge_feats,
                 batch_product_scores, batch_num_candidate_products).shape == \
           torch.Size([sum(batch_num_candidate_products), 1])

    model = WLNReactionRanking(node_in_feats=1,
                               edge_in_feats=1,
                               node_hidden_feats=100,
                               num_encode_gnn_layers=2).to(device)
    assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
                 product_node_feats, product_edge_feats, product_scores,
                 num_candidate_products).shape == \
           torch.Size([sum(num_candidate_products), 1])
    assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
                 batch_product_g, batch_product_node_feats, batch_product_edge_feats,
                 batch_product_scores, batch_num_candidate_products).shape == \
           torch.Size([sum(batch_num_candidate_products), 1])

if __name__ == '__main__':
    test_wln_reaction_center()
    test_wln_candidate_ranking()
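The ranking helper above leans on dgl.batch to merge each reaction's candidate products into one graph, with batch_num_candidate_products recording how many candidates belong to each reaction. A minimal standalone sketch of the batching behaviour being relied on, using the same DGL 0.4-era DGLGraph constructor as the tests:

import dgl
import numpy as np
from dgl import DGLGraph

edges = (np.array([0, 1, 2]), np.array([1, 2, 2]))
bg = dgl.batch([DGLGraph(edges) for _ in range(3)])
# Nodes and edges of the component graphs are concatenated in order, so
# per-graph feature slices stay contiguous for the ranking model to segment.
assert bg.number_of_nodes() == 9
assert bg.number_of_edges() == 9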
apps/life_sci/tests/model/test_readout.py deleted 100644 → 0
import dgl
import torch
import torch.nn.functional as F

from dgl import DGLGraph
from dgllife.model.readout import *

def test_graph1():
    """Graph with node features"""
    g = DGLGraph([(0, 1), (0, 2), (1, 2)])
    return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)

def test_graph2():
    """Batched graph with node features"""
    g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
    g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    bg = dgl.batch([g1, g2])
    return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)

def test_weighted_sum_and_max():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    model = WeightedSumAndMax(in_feats=1).to(device)
    assert model(g, node_feats).shape == torch.Size([1, 2])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 2])

def test_attentive_fp_readout():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    model = AttentiveFPReadout(feat_size=1, num_timesteps=1).to(device)
    assert model(g, node_feats).shape == torch.Size([1, 1])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 1])

def test_mlp_readout():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    model = MLPNodeReadout(node_feats=1, hidden_feats=2, graph_feats=3,
                           activation=F.relu, mode='sum').to(device)
    assert model(g, node_feats).shape == torch.Size([1, 3])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 3])

    model = MLPNodeReadout(node_feats=1, hidden_feats=2, graph_feats=3,
                           mode='max').to(device)
    assert model(g, node_feats).shape == torch.Size([1, 3])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 3])

    model = MLPNodeReadout(node_feats=1, hidden_feats=2, graph_feats=3,
                           mode='mean').to(device)
    assert model(g, node_feats).shape == torch.Size([1, 3])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 3])

def test_weave_readout():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats = test_graph1()
    g, node_feats = g.to(device), node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)

    model = WeaveGather(node_in_feats=1).to(device)
    assert model(g, node_feats).shape == torch.Size([1, 1])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 1])

    model = WeaveGather(node_in_feats=1, gaussian_expand=False).to(device)
    assert model(g, node_feats).shape == torch.Size([1, 1])
    assert model(bg, batch_node_feats).shape == torch.Size([2, 1])

if __name__ == '__main__':
    test_weighted_sum_and_max()
    test_attentive_fp_readout()
    test_mlp_readout()
    test_weave_readout()
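The shape assertions above pin down the contract of WeightedSumAndMax: for in_feats=1 it returns one 2 * in_feats-dimensional row per graph, i.e. a weighted-sum readout concatenated with a max readout. A sketch of a readout satisfying that contract, assuming a learned sigmoid gate for the node weights (the actual dgllife implementation may differ in details):

import dgl
import torch
import torch.nn as nn

class WeightedSumAndMaxSketch(nn.Module):
    def __init__(self, in_feats):
        super(WeightedSumAndMaxSketch, self).__init__()
        # One scalar weight in (0, 1) per node, learned from its features
        self.atom_weighting = nn.Sequential(nn.Linear(in_feats, 1), nn.Sigmoid())

    def forward(self, g, node_feats):
        g.ndata['h'] = node_feats
        g.ndata['w'] = self.atom_weighting(node_feats)
        h_sum = dgl.sum_nodes(g, 'h', 'w')  # weighted sum per graph
        h_max = dgl.max_nodes(g, 'h')       # elementwise max per graph
        h = torch.cat([h_sum, h_max], dim=-1)
        # Older DGL readouts may drop the batch dim for a single graph
        return h if h.dim() == 2 else h.unsqueeze(0)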
apps/life_sci/tests/scripts/build.sh deleted 100644 → 0
#!/bin/bash
# Argument
# - dev: cpu or gpu
if [ $# -ne 1 ]; then
    echo "Device argument required, can be cpu or gpu"
    exit -1
fi

dev=$1

set -e

. /opt/conda/etc/profile.d/conda.sh

rm -rf _deps
mkdir _deps
pushd _deps
conda activate "pytorch-ci"
# GPU runs need the CUDA 10.1 wheel (dgl-cu101); CPU runs need the plain wheel.
# The original script had the two install targets swapped between the branches.
if [ "$dev" == "gpu" ]; then
    pip uninstall -y dgl
    pip install --pre dgl-cu101
    python3 setup.py install
else
    pip uninstall -y dgl-cu101
    pip install --pre dgl
    python3 setup.py install
fi
popd
apps/life_sci/tests/scripts/task_lint.sh deleted 100644 → 0
#!/bin/bash
# Adapted from github.com/dmlc/dgl/tests/scripts/task_lint.sh

# pylint
echo 'Checking code style of python codes...'
python3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgllife || exit 1
apps/life_sci/tests/scripts/task_unit_test.sh deleted 100644 → 0
#!/bin/bash

. /opt/conda/etc/profile.d/conda.sh

function fail {
    echo FAIL: $@
    exit -1
}

function usage {
    echo "Usage: $0 backend device"
}

if [ $# -ne 2 ]; then
    usage
    fail "Error: must specify backend and device"
fi

export DGLBACKEND=$1
export PYTHONPATH=${PWD}/python:$PYTHONPATH
export DGL_DOWNLOAD_DIR=${PWD}

if [ $2 == "gpu" ]
then
    export CUDA_VISIBLE_DEVICES=0
else
    export CUDA_VISIBLE_DEVICES=-1
fi

conda activate ${DGLBACKEND}-ci
pip install _deps/dgl*.whl

python3 -m pytest -v --junitxml=pytest_data.xml tests/data || fail "data"
python3 -m pytest -v --junitxml=pytest_model.xml tests/model || fail "model"
python3 -m pytest -v --junitxml=pytest_utils.xml tests/utils || fail "utils"
apps/life_sci/tests/utils/test_complex_to_graph.py deleted 100644 → 0
import os
import shutil
import torch

from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.complex_to_graph import *
from dgllife.utils.rdkit_utils import load_molecule

def remove_dir(dir):
    if os.path.isdir(dir):
        try:
            shutil.rmtree(dir)
        except OSError:
            pass

def test_acnn_graph_construction_and_featurization():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    pocket_mol, pocket_coords = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)
    pocket_mol_with_h, pocket_coords_with_h = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=False)

    remove_dir('tmp1')
    remove_dir('tmp2')

    # Test default case
    g = ACNN_graph_construction_and_featurization(
        ligand_mol, pocket_mol, ligand_coords, pocket_coords)
    assert set(g.ntypes) == set(['protein_atom', 'ligand_atom'])
    assert set(g.etypes) == set(['protein', 'ligand', 'complex',
                                 'complex', 'complex', 'complex'])
    assert g.number_of_nodes('protein_atom') == 286
    assert g.number_of_nodes('ligand_atom') == 21
    assert g.number_of_edges('protein') == 3432
    assert g.number_of_edges('ligand') == 252
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3349
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 131
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 121
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 83
    assert 'atomic_number' in g.nodes['protein_atom'].data
    assert 'atomic_number' in g.nodes['ligand_atom'].data
    assert torch.allclose(g.nodes['protein_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('protein_atom'), 1))
    assert torch.allclose(g.nodes['ligand_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('ligand_atom'), 1))
    assert 'distance' in g.edges['protein'].data
    assert 'distance' in g.edges['ligand'].data
    assert 'distance' in g.edges[('protein_atom', 'complex', 'protein_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex', 'ligand_atom')].data
    assert 'distance' in g.edges[('protein_atom', 'complex', 'ligand_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex', 'protein_atom')].data

    # Test max_num_ligand_atoms and max_num_protein_atoms
    max_num_ligand_atoms = 30
    max_num_protein_atoms = 300
    g = ACNN_graph_construction_and_featurization(
        ligand_mol, pocket_mol, ligand_coords, pocket_coords,
        max_num_ligand_atoms=max_num_ligand_atoms,
        max_num_protein_atoms=max_num_protein_atoms)
    assert g.number_of_nodes('ligand_atom') == max_num_ligand_atoms
    assert g.number_of_nodes('protein_atom') == max_num_protein_atoms
    ligand_mask = torch.zeros(max_num_ligand_atoms, 1)
    ligand_mask[:ligand_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(ligand_mask, g.nodes['ligand_atom'].data['mask'])
    protein_mask = torch.zeros(max_num_protein_atoms, 1)
    protein_mask[:pocket_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(protein_mask, g.nodes['protein_atom'].data['mask'])

    # Test neighbor_cutoff
    neighbor_cutoff = 6.
    g = ACNN_graph_construction_and_featurization(
        ligand_mol, pocket_mol, ligand_coords, pocket_coords,
        neighbor_cutoff=neighbor_cutoff)
    assert g.number_of_edges('protein') == 3405
    assert g.number_of_edges('ligand') == 193
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3331
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 123
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 119
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 82

    # Test max_num_neighbors
    g = ACNN_graph_construction_and_featurization(
        ligand_mol, pocket_mol, ligand_coords, pocket_coords,
        max_num_neighbors=6)
    assert g.number_of_edges('protein') == 1716
    assert g.number_of_edges('ligand') == 126
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 1691
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 86
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 40
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 25

    # Test strip_hydrogens
    g = ACNN_graph_construction_and_featurization(
        pocket_mol_with_h, pocket_mol_with_h,
        pocket_coords_with_h, pocket_coords_with_h,
        strip_hydrogens=True)
    assert g.number_of_nodes('ligand_atom') != pocket_mol_with_h.GetNumAtoms()
    assert g.number_of_nodes('protein_atom') != pocket_mol_with_h.GetNumAtoms()
    non_h_atomic_numbers = []
    for i in range(pocket_mol_with_h.GetNumAtoms()):
        atom = pocket_mol_with_h.GetAtomWithIdx(i)
        if atom.GetSymbol() != 'H':
            non_h_atomic_numbers.append(atom.GetAtomicNum())
    non_h_atomic_numbers = torch.tensor(non_h_atomic_numbers).float().reshape(-1, 1)
    assert torch.allclose(non_h_atomic_numbers, g.nodes['ligand_atom'].data['atomic_number'])
    assert torch.allclose(non_h_atomic_numbers, g.nodes['protein_atom'].data['atomic_number'])

if __name__ == '__main__':
    test_acnn_graph_construction_and_featurization()
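Since ACNN_graph_construction_and_featurization returns a heterograph, the edge-count assertions above can also be enumerated generically. A small hypothetical helper (not part of the test file) that prints every canonical edge type and its count:

def describe_complex_graph(g):
    # g is the heterograph returned by ACNN_graph_construction_and_featurization
    for srctype, etype, dsttype in g.canonical_etypes:
        print(srctype, etype, dsttype, '->',
              g.number_of_edges((srctype, etype, dsttype)))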
apps/life_sci/tests/utils/test_early_stop.py deleted 100644 → 0
import os
import torch
import torch.nn as nn

from dgllife.utils import EarlyStopping

def remove_file(fname):
    if os.path.isfile(fname):
        try:
            os.remove(fname)
        except OSError:
            pass

def test_early_stopping_high():
    model1 = nn.Linear(2, 3)
    stopper = EarlyStopping(mode='higher', patience=1, filename='test.pkl')

    # Save model in the first step
    stopper.step(1., model1)
    model1.weight.data = model1.weight.data + 1
    model2 = nn.Linear(2, 3)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    # Save model checkpoint with performance improvement
    model1.weight.data = model1.weight.data + 1
    stopper.step(2., model1)
    stopper.load_checkpoint(model2)
    assert torch.allclose(model1.weight, model2.weight)

    # Stop when no improvement observed
    model1.weight.data = model1.weight.data + 1
    assert stopper.step(0.5, model1)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    remove_file('test.pkl')

def test_early_stopping_low():
    model1 = nn.Linear(2, 3)
    stopper = EarlyStopping(mode='lower', patience=1, filename='test.pkl')

    # Save model in the first step
    stopper.step(1., model1)
    model1.weight.data = model1.weight.data + 1
    model2 = nn.Linear(2, 3)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    # Save model checkpoint with performance improvement
    model1.weight.data = model1.weight.data + 1
    stopper.step(0.5, model1)
    stopper.load_checkpoint(model2)
    assert torch.allclose(model1.weight, model2.weight)

    # Stop when no improvement observed
    model1.weight.data = model1.weight.data + 1
    assert stopper.step(2, model1)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    remove_file('test.pkl')

if __name__ == '__main__':
    test_early_stopping_high()
    test_early_stopping_low()
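Outside the tests, the same EarlyStopping API supports a standard training loop: step returns True once patience is exhausted, and load_checkpoint restores the best weights. A minimal sketch with a stand-in model and hard-coded validation losses:

import torch.nn as nn
from dgllife.utils import EarlyStopping

model = nn.Linear(2, 3)
stopper = EarlyStopping(mode='lower', patience=2, filename='best_model.pkl')
for val_loss in [1.0, 0.8, 0.9, 0.95, 0.99]:  # stand-in validation losses
    if stopper.step(val_loss, model):
        break  # no improvement for `patience` consecutive epochs
stopper.load_checkpoint(model)  # restore the weights saved at val_loss == 0.8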
apps/life_sci/tests/utils/test_eval.py deleted 100644 → 0
import numpy as np
import torch

from dgllife.utils.eval import *

def test_Meter():
    label = torch.tensor([[0., 1.], [0., 1.], [1., 0.]])
    pred = torch.tensor([[0.5, 0.5], [0., 1.], [1., 0.]])
    mask = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
    label_mean, label_std = label.mean(dim=0), label.std(dim=0)

    # pearson r2
    meter = Meter(label_mean, label_std)
    meter.update(pred, label)
    true_scores = [0.7500000774286983, 0.7500000516191412]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
    assert meter.pearson_r2('sum') == np.sum(true_scores)
    assert meter.compute_metric('r2') == true_scores
    assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

    meter = Meter(label_mean, label_std)
    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
    assert meter.pearson_r2('sum') == np.sum(true_scores)
    assert meter.compute_metric('r2') == true_scores
    assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

    # mae
    meter = Meter()
    meter.update(pred, label)
    true_scores = [0.1666666716337204, 0.1666666716337204]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
    assert meter.mae('sum') == np.sum(true_scores)
    assert meter.compute_metric('mae') == true_scores
    assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

    meter = Meter()
    meter.update(pred, label, mask)
    true_scores = [0.25, 0.0]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
    assert meter.mae('sum') == np.sum(true_scores)
    assert meter.compute_metric('mae') == true_scores
    assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)
    # rmse
    meter = Meter(label_mean, label_std)
    meter.update(pred, label)
    true_scores = [0.41068359261794546, 0.4106836107598449]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')),
                          torch.tensor(true_scores))

    meter = Meter(label_mean, label_std)
    meter.update(pred, label, mask)
    true_scores = [0.44433766459035057, 0.5019903799993205]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')),
                          torch.tensor(true_scores))

    # roc auc score
    meter = Meter()
    meter.update(pred, label)
    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)
    assert meter.compute_metric('roc_auc_score') == true_scores
    assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

    meter = Meter()
    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)
    assert meter.compute_metric('roc_auc_score') == true_scores
    assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

def test_cases_with_undefined_scores():
    label = torch.tensor([[0., 1.], [0., 1.], [1., 1.]])
    pred = torch.tensor([[0.5, 0.5], [0., 1.], [1., 0.]])

    meter = Meter()
    meter.update(pred, label)
    true_scores = [1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)

if __name__ == '__main__':
    test_Meter()
    test_cases_with_undefined_scores()
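In an evaluation loop, Meter accumulates predictions batch by batch before a single metric call; the mask argument, as exercised above, excludes missing labels in multi-task datasets. A minimal sketch with stand-in batches:

import torch
from dgllife.utils.eval import Meter

meter = Meter()
for _ in range(2):  # pretend two evaluation batches
    pred = torch.tensor([[0.9, 0.2], [0.1, 0.7]])
    label = torch.tensor([[1., 0.], [0., 1.]])
    meter.update(pred, label)
print(meter.compute_metric('roc_auc_score', 'mean'))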
apps/life_sci/tests/utils/test_featurizers.py deleted 100644 → 0
import torch

from dgllife.utils.featurizers import *
from rdkit import Chem

def test_one_hot_encoding():
    x = 1.
    allowable_set = [0., 1., 2.]
    assert one_hot_encoding(x, allowable_set) == [0, 1, 0]
    assert one_hot_encoding(x, allowable_set, encode_unknown=True) == [0, 1, 0, 0]
    assert one_hot_encoding(x, allowable_set) == [0, 1, 0, 0]
    assert one_hot_encoding(x, allowable_set, encode_unknown=True) == [0, 1, 0, 0]
    assert one_hot_encoding(-1, allowable_set, encode_unknown=True) == [0, 0, 0, 1]
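Note that the second and third assertions only hold together if one_hot_encoding with encode_unknown=True mutates allowable_set in place, appending an 'unknown' slot that later calls then see. A sketch consistent with that behaviour (an inference from the assertions, not necessarily the dgllife source):

def one_hot_encoding_sketch(x, allowable_set, encode_unknown=False):
    if encode_unknown and (allowable_set[-1] is not None):
        # Side effect: the caller's list permanently gains a trailing
        # 'unknown' slot, which is why the plain call above sees 4 entries.
        allowable_set.append(None)
    if encode_unknown and (x not in allowable_set):
        x = None
    return [int(x == s) for s in allowable_set]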
def test_mol1():
    return Chem.MolFromSmiles('CCO')

def test_mol2():
    return Chem.MolFromSmiles('C1=CC2=CC=CC=CC2=C1')

def test_mol3():
    return Chem.MolFromSmiles('O=C(O)/C=C/C(=O)O')

def test_atom_type_one_hot():
    mol = test_mol1()
    assert atom_type_one_hot(mol.GetAtomWithIdx(0), ['C', 'O']) == [1, 0]
    assert atom_type_one_hot(mol.GetAtomWithIdx(2), ['C', 'O']) == [0, 1]

def test_atomic_number_one_hot():
    mol = test_mol1()
    assert atomic_number_one_hot(mol.GetAtomWithIdx(0), [6, 8]) == [1, 0]
    assert atomic_number_one_hot(mol.GetAtomWithIdx(2), [6, 8]) == [0, 1]

def test_atomic_number():
    mol = test_mol1()
    assert atomic_number(mol.GetAtomWithIdx(0)) == [6]
    assert atomic_number(mol.GetAtomWithIdx(2)) == [8]

def test_atom_degree_one_hot():
    mol = test_mol1()
    assert atom_degree_one_hot(mol.GetAtomWithIdx(0), [0, 1, 2]) == [0, 1, 0]
    assert atom_degree_one_hot(mol.GetAtomWithIdx(1), [0, 1, 2]) == [0, 0, 1]

def test_atom_degree():
    mol = test_mol1()
    assert atom_degree(mol.GetAtomWithIdx(0)) == [1]
    assert atom_degree(mol.GetAtomWithIdx(1)) == [2]

def test_atom_total_degree_one_hot():
    mol = test_mol1()
    assert atom_total_degree_one_hot(mol.GetAtomWithIdx(0), [0, 2, 4]) == [0, 0, 1]
    assert atom_total_degree_one_hot(mol.GetAtomWithIdx(2), [0, 2, 4]) == [0, 1, 0]

def test_atom_total_degree():
    mol = test_mol1()
    assert atom_total_degree(mol.GetAtomWithIdx(0)) == [4]
    assert atom_total_degree(mol.GetAtomWithIdx(2)) == [2]
def test_atom_explicit_valence_one_hot():
    mol = test_mol1()
    # Explicit valences of the terminal and middle carbon in CCO are 1 and 2
    assert atom_explicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [1, 0, 0]
    assert atom_explicit_valence_one_hot(mol.GetAtomWithIdx(1), [1, 2, 3]) == [0, 1, 0]
def test_atom_explicit_valence():
    mol = test_mol1()
    assert atom_explicit_valence(mol.GetAtomWithIdx(0)) == [1]
    assert atom_explicit_valence(mol.GetAtomWithIdx(1)) == [2]

def test_atom_implicit_valence_one_hot():
    mol = test_mol1()
    assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [0, 0, 1]
    assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(1), [1, 2, 3]) == [0, 1, 0]

def test_atom_implicit_valence():
    mol = test_mol1()
    assert atom_implicit_valence(mol.GetAtomWithIdx(0)) == [3]
    assert atom_implicit_valence(mol.GetAtomWithIdx(1)) == [2]

def test_atom_hybridization_one_hot():
    mol = test_mol1()
    assert atom_hybridization_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 1, 0, 0]

def test_atom_total_num_H_one_hot():
    mol = test_mol1()
    assert atom_total_num_H_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 0, 1, 0]
    assert atom_total_num_H_one_hot(mol.GetAtomWithIdx(1)) == [0, 0, 1, 0, 0]

def test_atom_total_num_H():
    mol = test_mol1()
    assert atom_total_num_H(mol.GetAtomWithIdx(0)) == [3]
    assert atom_total_num_H(mol.GetAtomWithIdx(1)) == [2]

def test_atom_formal_charge_one_hot():
    mol = test_mol1()
    assert atom_formal_charge_one_hot(mol.GetAtomWithIdx(0)) == [0, 0, 1, 0, 0]

def test_atom_formal_charge():
    mol = test_mol1()
    assert atom_formal_charge(mol.GetAtomWithIdx(0)) == [0]

def test_atom_num_radical_electrons_one_hot():
    mol = test_mol1()
    assert atom_num_radical_electrons_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0, 0]

def test_atom_num_radical_electrons():
    mol = test_mol1()
    assert atom_num_radical_electrons(mol.GetAtomWithIdx(0)) == [0]

def test_atom_is_aromatic_one_hot():
    mol = test_mol1()
    assert atom_is_aromatic_one_hot(mol.GetAtomWithIdx(0)) == [1, 0]
    mol = test_mol2()
    assert atom_is_aromatic_one_hot(mol.GetAtomWithIdx(0)) == [0, 1]

def test_atom_is_aromatic():
    mol = test_mol1()
    assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [0]
    mol = test_mol2()
    assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [1]

def test_atom_is_in_ring_one_hot():
    mol = test_mol1()
    assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [1, 0]
    mol = test_mol2()
    assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [0, 1]

def test_atom_is_in_ring():
    mol = test_mol1()
    assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [0]
    mol = test_mol2()
    assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [1]

def test_atom_chiral_tag_one_hot():
    mol = test_mol1()
    assert atom_chiral_tag_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0]

def test_atom_mass():
    mol = test_mol1()
    atom = mol.GetAtomWithIdx(0)
    assert atom_mass(atom) == [atom.GetMass() * 0.01]
    atom = mol.GetAtomWithIdx(1)
    assert atom_mass(atom) == [atom.GetMass() * 0.01]

def test_concat_featurizer():
    test_featurizer = ConcatFeaturizer([atom_is_aromatic_one_hot, atom_chiral_tag_one_hot])
    mol = test_mol1()
    assert test_featurizer(mol.GetAtomWithIdx(0)) == [1, 0, 1, 0, 0, 0]
    mol = test_mol2()
    assert test_featurizer(mol.GetAtomWithIdx(0)) == [0, 1, 1, 0, 0, 0]

class TestAtomFeaturizer(BaseAtomFeaturizer):
    def __init__(self):
        super(TestAtomFeaturizer, self).__init__(
            featurizer_funcs={
                'h1': ConcatFeaturizer([atom_total_degree_one_hot,
                                        atom_formal_charge_one_hot]),
                'h2': ConcatFeaturizer([atom_num_radical_electrons_one_hot])
            })

def test_base_atom_featurizer():
    test_featurizer = TestAtomFeaturizer()
    assert test_featurizer.feat_size('h1') == 11
    assert test_featurizer.feat_size('h2') == 5
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert torch.allclose(feats['h1'],
                          torch.tensor([[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
                                        [0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
                                        [0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]]))
    assert torch.allclose(feats['h2'],
                          torch.tensor([[1., 0., 0., 0., 0.],
                                        [1., 0., 0., 0., 0.],
                                        [1., 0., 0., 0., 0.]]))

def test_canonical_atom_featurizer():
    test_featurizer = CanonicalAtomFeaturizer()
    assert test_featurizer.feat_size() == 74
    assert test_featurizer.feat_size('h') == 74
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert list(feats.keys()) == ['h']
    assert torch.allclose(
        feats['h'],
        torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
                       0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                       0., 0., 1., 0.],
                      [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
                       0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                       0., 1., 0., 0.],
                      [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                       1., 0., 0., 0.]]))

def test_weave_atom_featurizer():
    featurizer = WeaveAtomFeaturizer()
    assert featurizer.feat_size() == 27
    mol = test_mol1()
    feats = featurizer(mol)
    assert list(feats.keys()) == ['h']
    assert torch.allclose(
        feats['h'],
        torch.tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                       -0.0418, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                      [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                       0.0402, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                      [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                       -0.3967, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000,
                       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]),
        rtol=1e-3)

def test_pretrain_atom_featurizer():
    featurizer = PretrainAtomFeaturizer()
    mol = test_mol1()
    feats = featurizer(mol)
    assert list(feats.keys()) == ['atomic_number', 'chirality_type']
    assert torch.allclose(feats['atomic_number'], torch.tensor([[5, 5, 7]]))
    assert torch.allclose(feats['chirality_type'], torch.tensor([[0, 0, 0]]))

def test_bond_type_one_hot():
    mol = test_mol1()
    assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0]
    mol = test_mol2()
    assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [0, 0, 0, 1]

def test_bond_is_conjugated_one_hot():
    mol = test_mol1()
    assert bond_is_conjugated_one_hot(mol.GetBondWithIdx(0)) == [1, 0]
    mol = test_mol2()
    assert bond_is_conjugated_one_hot(mol.GetBondWithIdx(0)) == [0, 1]

def test_bond_is_conjugated():
    mol = test_mol1()
    assert bond_is_conjugated(mol.GetBondWithIdx(0)) == [0]
    mol = test_mol2()
    assert bond_is_conjugated(mol.GetBondWithIdx(0)) == [1]

def test_bond_is_in_ring_one_hot():
    mol = test_mol1()
    assert bond_is_in_ring_one_hot(mol.GetBondWithIdx(0)) == [1, 0]
    mol = test_mol2()
    assert bond_is_in_ring_one_hot(mol.GetBondWithIdx(0)) == [0, 1]

def test_bond_is_in_ring():
    mol = test_mol1()
    assert bond_is_in_ring(mol.GetBondWithIdx(0)) == [0]
    mol = test_mol2()
    assert bond_is_in_ring(mol.GetBondWithIdx(0)) == [1]

def test_bond_stereo_one_hot():
    mol = test_mol1()
    assert bond_stereo_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0, 0, 0]

def test_bond_direction_one_hot():
    mol = test_mol3()
    assert bond_direction_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0]
    assert bond_direction_one_hot(mol.GetBondWithIdx(2)) == [0, 1, 0]

class TestBondFeaturizer(BaseBondFeaturizer):
    def __init__(self):
        super(TestBondFeaturizer, self).__init__(
            featurizer_funcs={
                'h1': ConcatFeaturizer([bond_is_in_ring, bond_is_conjugated]),
                'h2': ConcatFeaturizer([bond_stereo_one_hot])
            })

def test_base_bond_featurizer():
    test_featurizer = TestBondFeaturizer()
    assert test_featurizer.feat_size('h1') == 2
    assert test_featurizer.feat_size('h2') == 6
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert torch.allclose(feats['h1'],
                          torch.tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.]]))
    assert torch.allclose(feats['h2'],
                          torch.tensor([[1., 0., 0., 0., 0., 0.],
                                        [1., 0., 0., 0., 0., 0.],
                                        [1., 0., 0., 0., 0., 0.],
                                        [1., 0., 0., 0., 0., 0.]]))

def test_canonical_bond_featurizer():
    test_featurizer = CanonicalBondFeaturizer()
    assert test_featurizer.feat_size() == 12
    assert test_featurizer.feat_size('e') == 12
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert torch.allclose(feats['e'], torch.tensor(
        [[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]))

def test_weave_edge_featurizer():
    test_featurizer = WeaveEdgeFeaturizer()
    assert test_featurizer.feat_size() == 12
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert torch.allclose(feats['e'], torch.tensor(
        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

def test_pretrain_bond_featurizer():
    mol = test_mol3()
    test_featurizer = PretrainBondFeaturizer()
    feats = test_featurizer(mol)
    assert torch.allclose(feats['bond_type'].nonzero(),
                          torch.tensor([[0], [1], [6], [7], [10], [11], [14], [15],
                                        [16], [17], [18], [19], [20], [21]]))
    assert torch.allclose(feats['bond_direction_type'].nonzero(),
                          torch.tensor([[4], [5], [8], [9]]))

    test_featurizer = PretrainBondFeaturizer(self_loop=False)
    feats = test_featurizer(mol)
    assert torch.allclose(feats['bond_type'].nonzero(),
                          torch.tensor([[0], [1], [6], [7], [10], [11]]))
    assert torch.allclose(feats['bond_direction_type'].nonzero(),
                          torch.tensor([[4], [5], [8], [9]]))
if __name__ == '__main__':
    test_one_hot_encoding()
    test_atom_type_one_hot()
    test_atomic_number_one_hot()
    test_atomic_number()
    test_atom_degree_one_hot()
    test_atom_degree()
    test_atom_total_degree_one_hot()
    test_atom_total_degree()
    test_atom_explicit_valence_one_hot()
    test_atom_explicit_valence()
    test_atom_implicit_valence_one_hot()
    test_atom_implicit_valence()
    test_atom_hybridization_one_hot()
    test_atom_total_num_H_one_hot()
    test_atom_total_num_H()
    test_atom_formal_charge_one_hot()
    test_atom_formal_charge()
    test_atom_num_radical_electrons_one_hot()
    test_atom_num_radical_electrons()
    test_atom_is_aromatic_one_hot()
    test_atom_is_aromatic()
    test_atom_is_in_ring_one_hot()
    test_atom_is_in_ring()
    test_atom_chiral_tag_one_hot()
    test_atom_mass()
    test_concat_featurizer()
    test_base_atom_featurizer()
    test_canonical_atom_featurizer()
    test_weave_atom_featurizer()
    test_pretrain_atom_featurizer()
    test_bond_type_one_hot()
    test_bond_is_conjugated_one_hot()
    test_bond_is_conjugated()
    test_bond_is_in_ring_one_hot()
    test_bond_is_in_ring()
    test_bond_stereo_one_hot()
    test_bond_direction_one_hot()
    test_base_bond_featurizer()
    test_canonical_bond_featurizer()
    test_weave_edge_featurizer()
    test_pretrain_bond_featurizer()
apps/life_sci/tests/utils/test_mol_to_graph.py deleted 100644 → 0
import numpy as np
import torch

from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
from rdkit import Chem
from rdkit.Chem import AllChem

test_smiles1 = 'CCO'
test_smiles2 = 'Fc1ccccc1'
test_smiles3 = '[CH2:1]([CH3:2])[N:3]1[CH2:4][CH2:5][C:6]([CH3:16])' \
               '([CH3:17])[c:7]2[cH:8][cH:9][c:10]([N+:13]([O-:14])=[O:15])' \
               '[cH:11][c:12]21.[CH3:18][CH2:19][O:20][C:21]([CH3:22])=[O:23]'

class TestAtomFeaturizer(BaseAtomFeaturizer):
    def __init__(self):
        super(TestAtomFeaturizer, self).__init__(
            featurizer_funcs={'hv': ConcatFeaturizer([atomic_number])})

class TestBondFeaturizer(BaseBondFeaturizer):
    def __init__(self):
        super(TestBondFeaturizer, self).__init__(
            featurizer_funcs={'he': ConcatFeaturizer([bond_is_in_ring])})

def test_smiles_to_bigraph():
    # Test the case with self loops added.
    g1 = smiles_to_bigraph(test_smiles1, add_self_loop=True)
    src, dst = g1.edges()
    assert torch.allclose(src, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
    assert torch.allclose(dst, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))

    # Test the case without self loops.
    test_node_featurizer = TestAtomFeaturizer()
    test_edge_featurizer = TestBondFeaturizer()
    g2 = smiles_to_bigraph(test_smiles2, add_self_loop=False,
                           node_featurizer=test_node_featurizer,
                           edge_featurizer=test_edge_featurizer)
    assert torch.allclose(g2.ndata['hv'],
                          torch.tensor([[9.], [6.], [6.], [6.], [6.], [6.], [6.]]))
    assert torch.allclose(g2.edata['he'],
                          torch.tensor([[0.], [0.], [1.], [1.], [1.], [1.], [1.],
                                        [1.], [1.], [1.], [1.], [1.], [1.], [1.]]))

    # Test the case where atoms come with a default order and we do not
    # want to change the order, which is related to the application of
    # reaction center prediction.
    g3 = smiles_to_bigraph(test_smiles3, node_featurizer=test_node_featurizer,
                           canonical_atom_order=False)
    assert torch.allclose(g3.ndata['hv'],
                          torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.], [6.], [6.],
                                        [6.], [6.], [6.], [6.], [7.], [8.], [8.], [6.],
                                        [6.], [6.], [6.], [8.], [6.], [6.], [8.]]))

def test_mol_to_bigraph():
    mol1 = Chem.MolFromSmiles(test_smiles1)
    g1 = mol_to_bigraph(mol1, add_self_loop=True)
    src, dst = g1.edges()
    assert torch.allclose(src, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
    assert torch.allclose(dst, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))

    # Test the case without self loops.
    mol2 = Chem.MolFromSmiles(test_smiles2)
    test_node_featurizer = TestAtomFeaturizer()
    test_edge_featurizer = TestBondFeaturizer()
    g2 = mol_to_bigraph(mol2, add_self_loop=False,
                        node_featurizer=test_node_featurizer,
                        edge_featurizer=test_edge_featurizer)
    assert torch.allclose(g2.ndata['hv'],
                          torch.tensor([[9.], [6.], [6.], [6.], [6.], [6.], [6.]]))
    assert torch.allclose(g2.edata['he'],
                          torch.tensor([[0.], [0.], [1.], [1.], [1.], [1.], [1.],
                                        [1.], [1.], [1.], [1.], [1.], [1.], [1.]]))

    # Test the case where atoms come with a default order and we do not
    # want to change the order, which is related to the application of
    # reaction center prediction.
    mol3 = Chem.MolFromSmiles(test_smiles3)
    g3 = mol_to_bigraph(mol3, node_featurizer=test_node_featurizer,
                        canonical_atom_order=False)
    assert torch.allclose(g3.ndata['hv'],
                          torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.], [6.], [6.],
                                        [6.], [6.], [6.], [6.], [7.], [8.], [8.], [6.],
                                        [6.], [6.], [6.], [8.], [6.], [6.], [8.]]))

def test_smiles_to_complete_graph():
    test_node_featurizer = TestAtomFeaturizer()
    g1 = smiles_to_complete_graph(test_smiles1, add_self_loop=False,
                                  node_featurizer=test_node_featurizer)
    src, dst = g1.edges()
    assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
    assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
    assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))

    # Test the case where atoms come with a default order and we do not
    # want to change the order, which is related to the application of
    # reaction center prediction.
    g2 = smiles_to_complete_graph(test_smiles3, node_featurizer=test_node_featurizer,
                                  canonical_atom_order=False)
    assert torch.allclose(g2.ndata['hv'],
                          torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.], [6.], [6.],
                                        [6.], [6.], [6.], [6.], [7.], [8.], [8.], [6.],
                                        [6.], [6.], [6.], [8.], [6.], [6.], [8.]]))

def test_mol_to_complete_graph():
    test_node_featurizer = TestAtomFeaturizer()
    mol1 = Chem.MolFromSmiles(test_smiles1)
    g1 = mol_to_complete_graph(mol1, add_self_loop=False,
                               node_featurizer=test_node_featurizer)
    src, dst = g1.edges()
    assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
    assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
    assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))

    # Test the case where atoms come with a default order and we do not
    # want to change the order, which is related to the application of
    # reaction center prediction.
    mol2 = Chem.MolFromSmiles(test_smiles3)
    g2 = mol_to_complete_graph(mol2, node_featurizer=test_node_featurizer,
                               canonical_atom_order=False)
    assert torch.allclose(g2.ndata['hv'],
                          torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.], [6.], [6.],
                                        [6.], [6.], [6.], [6.], [7.], [8.], [8.], [6.],
                                        [6.], [6.], [6.], [8.], [6.], [6.], [8.]]))

def test_k_nearest_neighbors():
    coordinates = np.array([[0.1, 0.1, 0.1],
                            [0.2, 0.1, 0.1],
                            [0.15, 0.15, 0.1],
                            [0.1, 0.15, 0.16],
                            [1.2, 0.1, 0.1],
                            [1.3, 0.2, 0.1]])
    neighbor_cutoff = 1.
    max_num_neighbors = 2
    srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff,
                                            max_num_neighbors)
    assert srcs == [2, 3, 2, 0, 0, 1, 0, 2, 1, 5, 4]
    assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]
    assert dists == [0.07071067811865478, 0.0781024967590666, 0.07071067811865483,
                     0.1, 0.07071067811865478, 0.07071067811865483,
                     0.0781024967590666, 0.0781024967590666, 1.0,
                     0.14142135623730956, 0.14142135623730956]

    # Test the case where self loops are included
    srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff,
                                            max_num_neighbors, self_loops=True)
    assert srcs == [0, 2, 1, 2, 2, 0, 3, 0, 4, 5, 4, 5]
    assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
    assert dists == [0.0, 0.07071067811865478, 0.0, 0.07071067811865483, 0.0,
                     0.07071067811865478, 0.0, 0.0781024967590666, 0.0,
                     0.14142135623730956, 0.14142135623730956, 0.0]

    # Test the case where max_num_neighbors is not given
    srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff=10.)
    assert srcs == [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5,
                    0, 1, 2, 4, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 4]
    assert dsts == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
                    3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5]
    assert dists == [0.1, 0.07071067811865478, 0.0781024967590666, 1.1,
                     1.2041594578792296, 0.1, 0.07071067811865483,
                     0.12688577540449525, 1.0, 1.104536101718726,
                     0.07071067811865478, 0.07071067811865483, 0.0781024967590666,
                     1.0511898020814319, 1.151086443322134, 0.0781024967590666,
                     0.12688577540449525, 0.0781024967590666, 1.1027692415006867,
                     1.202538980657176, 1.1, 1.0, 1.0511898020814319,
                     1.1027692415006867, 0.14142135623730956, 1.2041594578792296,
                     1.104536101718726, 1.151086443322134, 1.202538980657176,
                     0.14142135623730956]

def test_smiles_to_nearest_neighbor_graph():
    mol = Chem.MolFromSmiles(test_smiles1)
    AllChem.EmbedMolecule(mol)
    coordinates = mol.GetConformers()[0].GetPositions()

    # Test node featurizer
    test_node_featurizer = TestAtomFeaturizer()
    g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
                                         node_featurizer=test_node_featurizer)
    assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
    assert g.number_of_edges() == 6
    assert 'dist' not in g.edata

    # Test self loops
    g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
                                         add_self_loop=True)
    assert g.number_of_edges() == 9

    # Test max_num_neighbors
    g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
                                         max_num_neighbors=1, add_self_loop=True)
    assert g.number_of_edges() == 3

    # Test pairwise distances
    g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
                                         keep_dists=True)
    assert 'dist' in g.edata
    coordinates = torch.from_numpy(coordinates)
    srcs, dsts = g.edges()
    dist = torch.norm(coordinates[srcs] - coordinates[dsts],
                      dim=1, p=2).float().reshape(-1, 1)
    assert torch.allclose(dist, g.edata['dist'])

def test_mol_to_nearest_neighbor_graph():
    mol = Chem.MolFromSmiles(test_smiles1)
    AllChem.EmbedMolecule(mol)
    coordinates = mol.GetConformers()[0].GetPositions()

    # Test node featurizer
    test_node_featurizer = TestAtomFeaturizer()
    g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
                                      node_featurizer=test_node_featurizer)
    assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
    assert g.number_of_edges() == 6
    assert 'dist' not in g.edata

    # Test self loops
    g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
                                      add_self_loop=True)
    assert g.number_of_edges() == 9

    # Test max_num_neighbors
    g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
                                      max_num_neighbors=1, add_self_loop=True)
    assert g.number_of_edges() == 3

    # Test pairwise distances
    g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
                                      keep_dists=True)
    assert 'dist' in g.edata
    coordinates = torch.from_numpy(coordinates)
    srcs, dsts = g.edges()
    dist = torch.norm(coordinates[srcs] - coordinates[dsts],
                      dim=1, p=2).float().reshape(-1, 1)
    assert torch.allclose(dist, g.edata['dist'])

if __name__ == '__main__':
    test_smiles_to_bigraph()
    test_mol_to_bigraph()
    test_smiles_to_complete_graph()
    test_mol_to_complete_graph()
    test_k_nearest_neighbors()
    test_smiles_to_nearest_neighbor_graph()
    test_mol_to_nearest_neighbor_graph()
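The keep_dists check above can equivalently be phrased with torch.cdist, computing all pairwise distances at once and indexing by the graph's edge endpoints. A sketch, assuming the same CCO setup as the test:

import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from dgllife.utils.mol_to_graph import mol_to_nearest_neighbor_graph

mol = Chem.MolFromSmiles('CCO')
AllChem.EmbedMolecule(mol)
coords = torch.from_numpy(mol.GetConformers()[0].GetPositions())
g = mol_to_nearest_neighbor_graph(mol, coords.numpy(), neighbor_cutoff=10,
                                  keep_dists=True)
srcs, dsts = g.edges()
full = torch.cdist(coords, coords, p=2)  # (num_atoms, num_atoms) distance matrix
assert torch.allclose(full[srcs, dsts].float().reshape(-1, 1), g.edata['dist'])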
apps/life_sci/tests/utils/test_rdkit_utils.py deleted 100644 → 0
import numpy as np
import os
import shutil

from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.rdkit_utils import get_mol_3d_coordinates, load_molecule
from rdkit import Chem
from rdkit.Chem import AllChem

def test_get_mol_3D_coordinates():
    mol = Chem.MolFromSmiles('CCO')
    # Test the case when conformation does not exist
    assert get_mol_3d_coordinates(mol) is None

    # Test the case when conformation exists
    AllChem.EmbedMolecule(mol)
    AllChem.MMFFOptimizeMolecule(mol)
    coords = get_mol_3d_coordinates(mol)
    assert isinstance(coords, np.ndarray)
    assert coords.shape == (mol.GetNumAtoms(), 3)

def remove_dir(dir):
    if os.path.isdir(dir):
        try:
            shutil.rmtree(dir)
        except OSError:
            pass

def test_load_molecule():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    load_molecule('tmp2/example_mols/example.sdf')
    load_molecule('tmp2/example_mols/example.mol2', use_conformation=False,
                  sanitize=True)
    load_molecule('tmp2/example_mols/example.pdbqt', calc_charges=True)
    mol, _ = load_molecule('tmp2/example_mols/example.pdb', remove_hs=True)
    assert mol.GetNumAtoms() == mol.GetNumHeavyAtoms()

    remove_dir('tmp1')
    remove_dir('tmp2')

if __name__ == '__main__':
    test_get_mol_3D_coordinates()
    test_load_molecule()
apps/life_sci/tests/utils/test_splitters.py deleted 100644 → 0
import torch

from dgllife.utils.splitters import *
from rdkit import Chem

class TestDataset(object):
    def __init__(self):
        self.smiles = ['CCO', 'C1CCCCC1', 'O1CCOCC1', 'C1CCCC2C1CCCC2', 'N#N']
        self.mols = [Chem.MolFromSmiles(s) for s in self.smiles]
        self.labels = torch.arange(2 * len(self.smiles)).reshape(len(self.smiles), -1)

    def __getitem__(self, item):
        return self.smiles[item], self.mols[item]

    def __len__(self):
        return len(self.smiles)

def test_consecutive_splitter(dataset):
    ConsecutiveSplitter.train_val_test_split(dataset)
    ConsecutiveSplitter.k_fold_split(dataset)

def test_random_splitter(dataset):
    RandomSplitter.train_val_test_split(dataset, random_state=0)
    RandomSplitter.k_fold_split(dataset)

def test_molecular_weight_splitter(dataset):
    MolecularWeightSplitter.train_val_test_split(dataset)
    MolecularWeightSplitter.k_fold_split(dataset, mols=dataset.mols)

def test_scaffold_splitter(dataset):
    ScaffoldSplitter.train_val_test_split(dataset, include_chirality=True)
    ScaffoldSplitter.k_fold_split(dataset, mols=dataset.mols)

def test_single_task_stratified_splitter(dataset):
    SingleTaskStratifiedSplitter.train_val_test_split(dataset, dataset.labels, 1)
    SingleTaskStratifiedSplitter.k_fold_split(dataset, dataset.labels, 1)

if __name__ == '__main__':
    dataset = TestDataset()
    test_consecutive_splitter(dataset)
    test_random_splitter(dataset)
    test_molecular_weight_splitter(dataset)
    test_scaffold_splitter(dataset)
    test_single_task_stratified_splitter(dataset)
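For completeness, a hedged example of consuming a three-way split; the frac_* keyword names follow common dgllife usage and are an assumption here:

dataset = TestDataset()
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
    dataset, frac_train=0.6, frac_val=0.2, frac_test=0.2, random_state=0)
# For the 5-molecule dataset above, roughly a 3/1/1 split is expected
print(len(train_set), len(val_set), len(test_set))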
docs/source/api/python/data.rst
@@ -129,160 +129,3 @@ Protein-Protein Interaction dataset
.. autoclass:: PPIDataset
    :members: __getitem__, __len__

Molecular Graphs
----------------

To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.

Data Loading and Processing Utils
`````````````````````````````````

We adapt several utilities for processing molecules from
`DeepChem <https://github.com/deepchem/deepchem/blob/master/deepchem>`__.

.. autosummary::
    :toctree: ../../generated/

    chem.add_hydrogens_to_mol
    chem.get_mol_3D_coordinates
    chem.load_molecule
    chem.multiprocess_load_molecules

Featurization Utils for Single Molecule
```````````````````````````````````````

For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds).

General utils:

.. autosummary::
    :toctree: ../../generated/

    chem.one_hot_encoding
    chem.ConcatFeaturizer
    chem.ConcatFeaturizer.__call__

Utils for atom featurization:

.. autosummary::
    :toctree: ../../generated/

    chem.atom_type_one_hot
    chem.atomic_number_one_hot
    chem.atomic_number
    chem.atom_degree_one_hot
    chem.atom_degree
    chem.atom_total_degree_one_hot
    chem.atom_total_degree
    chem.atom_implicit_valence_one_hot
    chem.atom_implicit_valence
    chem.atom_hybridization_one_hot
    chem.atom_total_num_H_one_hot
    chem.atom_total_num_H
    chem.atom_formal_charge_one_hot
    chem.atom_formal_charge
    chem.atom_num_radical_electrons_one_hot
    chem.atom_num_radical_electrons
    chem.atom_is_aromatic_one_hot
    chem.atom_is_aromatic
    chem.atom_chiral_tag_one_hot
    chem.atom_mass
    chem.BaseAtomFeaturizer
    chem.BaseAtomFeaturizer.feat_size
    chem.BaseAtomFeaturizer.__call__
    chem.CanonicalAtomFeaturizer

Utils for bond featurization:

.. autosummary::
    :toctree: ../../generated/

    chem.bond_type_one_hot
    chem.bond_is_conjugated_one_hot
    chem.bond_is_conjugated
    chem.bond_is_in_ring_one_hot
    chem.bond_is_in_ring
    chem.bond_stereo_one_hot
    chem.BaseBondFeaturizer
    chem.BaseBondFeaturizer.feat_size
    chem.BaseBondFeaturizer.__call__
    chem.CanonicalBondFeaturizer

Graph Construction for Single Molecule
``````````````````````````````````````

Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects are listed below:

.. autosummary::
    :toctree: ../../generated/

    chem.mol_to_graph
    chem.smiles_to_bigraph
    chem.mol_to_bigraph
    chem.smiles_to_complete_graph
    chem.mol_to_complete_graph
    chem.k_nearest_neighbors

Graph Construction and Featurization for Ligand-Protein Complex
```````````````````````````````````````````````````````````````

Constructing DGLHeteroGraphs and featurizing them.

.. autosummary::
    :toctree: ../../generated/

    chem.ACNN_graph_construction_and_featurization

Dataset Classes
```````````````

If your dataset is stored in a ``.csv`` file, you may find it helpful to use

.. autoclass:: dgl.data.chem.MoleculeCSVDataset
    :members: __getitem__, __len__

Currently four datasets are supported:

* Tox21
* TencentAlchemyDataset
* PubChemBioAssayAromaticity
* PDBBind

.. autoclass:: dgl.data.chem.Tox21
    :members: __getitem__, __len__, task_pos_weights

.. autoclass:: dgl.data.chem.TencentAlchemyDataset
    :members: __getitem__, __len__, set_mean_and_std

.. autoclass:: dgl.data.chem.PubChemBioAssayAromaticity
    :members: __getitem__, __len__

.. autoclass:: dgl.data.chem.PDBBind
    :members: __getitem__, __len__

Dataset Splitting
`````````````````

We provide support for some common data splitting methods:

* consecutive split
* random split
* molecular weight split
* Bemis-Murcko scaffold split
* single-task-stratified split

.. autoclass:: dgl.data.chem.ConsecutiveSplitter
    :members: train_val_test_split, k_fold_split

.. autoclass:: dgl.data.chem.RandomSplitter
    :members: train_val_test_split, k_fold_split

.. autoclass:: dgl.data.chem.MolecularWeightSplitter
    :members: train_val_test_split, k_fold_split

.. autoclass:: dgl.data.chem.ScaffoldSplitter
    :members: train_val_test_split, k_fold_split

.. autoclass:: dgl.data.chem.SingleTaskStratifiedSplitter
    :members: train_val_test_split, k_fold_split
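As a worked illustration of the canonical featurizers and graph constructors documented above, using the dgllife modules the unit tests earlier in this commit import (the shapes come straight from those tests: 74-dim atom features under key 'h', 12-dim bond features under key 'e'):

from dgllife.utils.featurizers import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils.mol_to_graph import smiles_to_bigraph

g = smiles_to_bigraph('CCO',
                      node_featurizer=CanonicalAtomFeaturizer(),
                      edge_featurizer=CanonicalBondFeaturizer())
print(g.ndata['h'].shape)  # torch.Size([3, 74]): one row per atom
print(g.edata['e'].shape)  # torch.Size([4, 12]): one row per directed bond edge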
docs/source/api/python/model_zoo.rst deleted 100644 → 0
.. _apimodelzoo:

dgl.model_zoo
==============

.. currentmodule:: dgl.model_zoo

Chemistry
---------

Utils
`````

.. autosummary::
    :toctree: ../../generated/

    chem.load_pretrained

Property Prediction
```````````````````

Currently supported model architectures:

* GCNClassifier
* GATClassifier
* MPNN
* SchNet
* MGCN
* AttentiveFP

.. autoclass:: dgl.model_zoo.chem.GCNClassifier
    :members: forward

.. autoclass:: dgl.model_zoo.chem.GATClassifier
    :members: forward

.. autoclass:: dgl.model_zoo.chem.MPNNModel
    :members: forward

.. autoclass:: dgl.model_zoo.chem.SchNet
    :members: forward

.. autoclass:: dgl.model_zoo.chem.MGCNModel
    :members: forward

.. autoclass:: dgl.model_zoo.chem.AttentiveFP
    :members: forward

Generative Models
`````````````````

Currently supported model architectures:

* DGMG
* JTNN

.. autoclass:: dgl.model_zoo.chem.DGMG
    :members: forward

.. autoclass:: dgl.model_zoo.chem.DGLJTNNVAE
    :members: forward

Protein Ligand Binding
``````````````````````

Currently supported model architectures:

* ACNN

.. autoclass:: dgl.model_zoo.chem.ACNN
    :members: forward
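A hedged usage sketch for the load_pretrained util documented above; the 'GCN_Tox21' model name is an assumption for illustration:

from dgl import model_zoo

# Hypothetical pretrained-model name; load_pretrained is the documented entry point
model = model_zoo.chem.load_pretrained('GCN_Tox21')
model.eval()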