Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2faff451
"vscode:/vscode.git/clone" did not exist on "fcc28e957f0089783d3685b77a208e131c818c97"
Commit
2faff451
authored
Oct 28, 2021
by
Gustaf Ahdritz
Browse files
Make template code use less memory during inference
parent
407d9924
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
59 additions
and
33 deletions
+59
-33
openfold/config.py
openfold/config.py
+1
-2
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+3
-1
openfold/model/model.py
openfold/model/model.py
+8
-13
openfold/model/template.py
openfold/model/template.py
+45
-14
tests/test_template.py
tests/test_template.py
+2
-3
No files found.
openfold/config.py
View file @
2faff451
...
...
@@ -231,8 +231,7 @@ config = mlc.ConfigDict(
# Recurring FieldReferences that can be changed globally here
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"train_chunk_size"
:
None
,
"eval_chunk_size"
:
chunk_size
,
"chunk_size"
:
chunk_size
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_t"
:
c_t
,
...
...
openfold/data/data_pipeline.py
View file @
2faff451
...
...
@@ -210,7 +210,9 @@ class AlignmentRunner:
)
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
]
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
],
n_cpu
=
no_cpus
,
)
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
...
...
openfold/model/model.py
View file @
2faff451
...
...
@@ -106,7 +106,7 @@ class AlphaFold(nn.Module):
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
chunk_size
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds
=
[]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
@@ -146,18 +146,20 @@ class AlphaFold(nn.Module):
template_embeds
,
)
# [*, N, N, C_z]
# [*,
S_t,
N, N, C_z]
t
=
self
.
template_pair_stack
(
template_embeds
[
"pair"
],
pair_mask
.
unsqueeze
(
-
3
),
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
t
,
z
,
template_mask
=
batch
[
"template_mask"
],
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
...
...
@@ -170,12 +172,6 @@ class AlphaFold(nn.Module):
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
_recycle
=
True
):
# Establish constants
chunk_size
=
(
self
.
globals
.
train_chunk_size
if
self
.
training
else
self
.
globals
.
eval_chunk_size
)
# Primary output dictionary
outputs
=
{}
...
...
@@ -251,7 +247,6 @@ class AlphaFold(nn.Module):
z
,
pair_mask
,
no_batch_dims
,
chunk_size
,
)
# [*, N, N, C_z]
...
...
@@ -281,7 +276,7 @@ class AlphaFold(nn.Module):
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
],
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
pair_mask
=
pair_mask
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -295,7 +290,7 @@ class AlphaFold(nn.Module):
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
openfold/model/template.py
View file @
2faff451
...
...
@@ -174,17 +174,43 @@ class TemplatePairStackBlock(nn.Module):
)
def
forward
(
self
,
z
,
mask
,
chunk_size
,
_mask_trans
=
True
):
z
=
z
+
self
.
dropout_row
(
self
.
tri_att_start
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
)
z
=
z
+
self
.
dropout_col
(
self
.
tri_att_end
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
)
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_out
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_in
(
z
,
mask
=
mask
))
z
=
z
+
self
.
pair_transition
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
if
_mask_trans
else
None
)
for
templ_idx
in
range
(
z
.
shape
[
-
4
]):
# Select a single template at a time
single
=
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
single_mask
=
mask
[...,
templ_idx
:
templ_idx
+
1
,
:,
:]
single
=
single
+
self
.
dropout_row
(
self
.
tri_att_start
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
)
)
single
=
single
+
self
.
dropout_col
(
self
.
tri_att_end
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
)
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_out
(
single
,
mask
=
single_mask
)
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_in
(
single
,
mask
=
single_mask
)
)
single
=
single
+
self
.
pair_transition
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
if
_mask_trans
else
None
)
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
=
single
return
z
...
...
@@ -254,12 +280,17 @@ class TemplatePairStack(nn.Module):
"""
Args:
t:
[*, N_res, N_res, C_t] template embedding
[*,
N_templ,
N_res, N_res, C_t] template embedding
mask:
[*, N_res, N_res] mask
[*,
N_templ,
N_res, N_res] mask
Returns:
[*, N_res, N_res, C_t] template embedding update
[*,
N_templ,
N_res, N_res, C_t] template embedding update
"""
if
(
mask
.
shape
[
-
3
]
==
1
):
expand_idx
=
list
(
mask
.
shape
)
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
mask
=
mask
.
expand
(
*
expand_idx
)
(
t
,)
=
checkpoint_blocks
(
blocks
=
[
partial
(
...
...
tests/test_template.py
View file @
2faff451
...
...
@@ -133,8 +133,8 @@ class TestTemplatePairStack(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
_mask_trans
=
False
,
).
cpu
()
...
...
@@ -182,7 +182,6 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
None
,
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment