Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
3260242f
Commit
3260242f
authored
Nov 13, 2021
by
Gustaf Ahdritz
Browse files
Add mechanism for TorchScript'ing primitives during inference
parent
1f87c6e5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
2 deletions
+33
-2
openfold/model/primitives.py
openfold/model/primitives.py
+1
-2
openfold/utils/torchscript_utils.py
openfold/utils/torchscript_utils.py
+24
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+8
-0
No files found.
openfold/model/primitives.py
View file @
3260242f
...
...
@@ -166,9 +166,8 @@ class Linear(nn.Linear):
class
Attention
(
nn
.
Module
):
"""
Standard multi-head attention using AlphaFold's default layer
initialization.
initialization.
Allows multiple bias vectors.
"""
def
__init__
(
self
,
c_q
:
int
,
...
...
openfold/utils/torchscript_utils.py
0 → 100644
View file @
3260242f
from
typing
import
Optional
,
Sequence
import
torch
import
torch.nn
as
nn
def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
):
    """
    Convert all submodules whose types match one of those in the input
    list to recursively scripted equivalents in place. To script the entire
    model, just call torch.jit.script on it directly.

    Args:
        model:
            A torch.nn.Module whose matching submodules are replaced
            in place with TorchScript-compiled equivalents.
        types:
            A list of types of submodules to script. If None, every
            direct child of ``model`` is scripted. (Previously, the
            None default crashed with a TypeError because it was
            iterated unconditionally.)
    """
    for name, child in model.named_children():
        # None means "script everything"; otherwise script only children
        # matching one of the requested types.
        if types is None or any(isinstance(child, t) for t in types):
            setattr(model, name, torch.jit.script(child))
        else:
            # Non-matching children may still contain matching
            # descendants, so recurse.
            script_submodules_(child, types)
run_pretrained_openfold.py
View file @
3260242f
...
...
@@ -31,22 +31,30 @@ import torch
from
openfold.config
import
model_config
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.model.model
import
AlphaFold
from
openfold.model.primitives
import
Attention
,
GlobalAttention
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
openfold.utils.torchscript_utils
import
script_submodules_
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
from
scripts.utils
import
add_data_args
def script_primitives_(model):
    """
    TorchScript-compile a model's basic attention primitives in place.

    Delegates to script_submodules_, restricting compilation to the
    attention primitive types used by OpenFold.
    """
    primitive_types = [Attention, GlobalAttention]
    script_submodules_(model, primitive_types)
def
main
(
args
):
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
args
.
param_path
)
script_primitives_
(
model
)
model
=
model
.
to
(
args
.
model_device
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment