Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e310bba4
"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "beb249e5a5f60afcf151405f3e0444700148859f"
Commit
e310bba4
authored
Apr 05, 2022
by
Gustaf Ahdritz
Browse files
Improve memory management in extra msa stack
parent
cdadff32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
6 deletions
+16
-6
openfold/model/evoformer.py
openfold/model/evoformer.py
+16
-6
No files found.
openfold/model/evoformer.py
View file @
e310bba4
...
@@ -352,20 +352,30 @@ class ExtraMSABlock(nn.Module):
...
@@ -352,20 +352,30 @@ class ExtraMSABlock(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
m
=
m
+
self
.
msa_dropout_layer
(
def
add
(
m1
,
m2
):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if
(
torch
.
is_grad_enabled
()):
m1
=
m1
+
m2
else
:
m1
+=
m2
return
m1
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
self
.
msa_att_row
(
m
.
clone
(),
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
z
=
z
.
clone
(),
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
mask
=
msa_mask
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
_checkpoint_chunks
=
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
)
)
)
)
def
fn
(
m
,
z
):
def
fn
(
m
,
z
):
m
=
m
+
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
m
=
add
(
m
,
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
)
m
,
z
=
self
.
core
(
m
,
z
=
self
.
core
(
m
,
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
m
,
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment