Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2faff451
Commit
2faff451
authored
Oct 28, 2021
by
Gustaf Ahdritz
Browse files
Make template code use less memory during inference
parent
407d9924
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
59 additions
and
33 deletions
+59
-33
openfold/config.py
openfold/config.py
+1
-2
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+3
-1
openfold/model/model.py
openfold/model/model.py
+8
-13
openfold/model/template.py
openfold/model/template.py
+45
-14
tests/test_template.py
tests/test_template.py
+2
-3
No files found.
openfold/config.py
View file @
2faff451
...
...
@@ -231,8 +231,7 @@ config = mlc.ConfigDict(
# Recurring FieldReferences that can be changed globally here
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"train_chunk_size"
:
None
,
"eval_chunk_size"
:
chunk_size
,
"chunk_size"
:
chunk_size
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_t"
:
c_t
,
...
...
openfold/data/data_pipeline.py
View file @
2faff451
...
...
@@ -210,7 +210,9 @@ class AlignmentRunner:
)
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
]
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
],
n_cpu
=
no_cpus
,
)
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
...
...
openfold/model/model.py
View file @
2faff451
...
...
@@ -106,7 +106,7 @@ class AlphaFold(nn.Module):
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
chunk_size
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds
=
[]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
@@ -146,18 +146,20 @@ class AlphaFold(nn.Module):
template_embeds
,
)
# [*, N, N, C_z]
# [*,
S_t,
N, N, C_z]
t
=
self
.
template_pair_stack
(
template_embeds
[
"pair"
],
pair_mask
.
unsqueeze
(
-
3
),
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
t
,
z
,
template_mask
=
batch
[
"template_mask"
],
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
...
...
@@ -170,12 +172,6 @@ class AlphaFold(nn.Module):
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
_recycle
=
True
):
# Establish constants
chunk_size
=
(
self
.
globals
.
train_chunk_size
if
self
.
training
else
self
.
globals
.
eval_chunk_size
)
# Primary output dictionary
outputs
=
{}
...
...
@@ -251,7 +247,6 @@ class AlphaFold(nn.Module):
z
,
pair_mask
,
no_batch_dims
,
chunk_size
,
)
# [*, N, N, C_z]
...
...
@@ -281,7 +276,7 @@ class AlphaFold(nn.Module):
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
],
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
pair_mask
=
pair_mask
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -295,7 +290,7 @@ class AlphaFold(nn.Module):
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
openfold/model/template.py
View file @
2faff451
...
...
@@ -174,17 +174,43 @@ class TemplatePairStackBlock(nn.Module):
)
def
forward
(
self
,
z
,
mask
,
chunk_size
,
_mask_trans
=
True
):
z
=
z
+
self
.
dropout_row
(
self
.
tri_att_start
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
for
templ_idx
in
range
(
z
.
shape
[
-
4
]):
# Select a single template at a time
single
=
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
single_mask
=
mask
[...,
templ_idx
:
templ_idx
+
1
,
:,
:]
single
=
single
+
self
.
dropout_row
(
self
.
tri_att_start
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
)
)
single
=
single
+
self
.
dropout_col
(
self
.
tri_att_end
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
)
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_out
(
single
,
mask
=
single_mask
)
z
=
z
+
self
.
dropout_col
(
self
.
tri_att_end
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
)
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_out
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_in
(
z
,
mask
=
mask
))
z
=
z
+
self
.
pair_transition
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
if
_mask_trans
else
None
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_in
(
single
,
mask
=
single_mask
)
)
single
=
single
+
self
.
pair_transition
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
if
_mask_trans
else
None
)
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
=
single
return
z
...
...
@@ -254,12 +280,17 @@ class TemplatePairStack(nn.Module):
"""
Args:
t:
[*, N_res, N_res, C_t] template embedding
[*,
N_templ,
N_res, N_res, C_t] template embedding
mask:
[*, N_res, N_res] mask
[*,
N_templ,
N_res, N_res] mask
Returns:
[*, N_res, N_res, C_t] template embedding update
[*,
N_templ,
N_res, N_res, C_t] template embedding update
"""
if
(
mask
.
shape
[
-
3
]
==
1
):
expand_idx
=
list
(
mask
.
shape
)
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
mask
=
mask
.
expand
(
*
expand_idx
)
(
t
,)
=
checkpoint_blocks
(
blocks
=
[
partial
(
...
...
tests/test_template.py
View file @
2faff451
...
...
@@ -133,8 +133,8 @@ class TestTemplatePairStack(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
_mask_trans
=
False
,
).
cpu
()
...
...
@@ -182,7 +182,6 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
None
,
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment