Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
96ca6460
Commit
96ca6460
authored
Jun 17, 2022
by
Gustaf Ahdritz
Browse files
Scale back chunk size tuning a little
parent
f8f74006
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
8 deletions
+20
-8
openfold/model/evoformer.py
openfold/model/evoformer.py
+7
-4
openfold/model/template.py
openfold/model/template.py
+13
-4
No files found.
openfold/model/evoformer.py
View file @
96ca6460
...
...
@@ -332,6 +332,9 @@ class EvoformerBlock(nn.Module):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
print
(
chunk_size
)
print
(
_attn_chunk_size
)
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
...
...
@@ -653,7 +656,7 @@ class EvoformerStack(nn.Module):
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
2
),
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
...
...
@@ -783,7 +786,7 @@ class ExtraMSAStack(nn.Module):
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
2
),
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
...
...
openfold/model/template.py
View file @
96ca6460
...
...
@@ -201,7 +201,11 @@ class TemplatePairStackBlock(nn.Module):
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
]
...
...
@@ -216,7 +220,7 @@ class TemplatePairStackBlock(nn.Module):
self
.
dropout_row
(
self
.
tri_att_start
(
single
,
chunk_size
=
chunk_size
,
chunk_size
=
_attn_
chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
)
...
...
@@ -228,7 +232,7 @@ class TemplatePairStackBlock(nn.Module):
self
.
dropout_col
(
self
.
tri_att_end
(
single
,
chunk_size
=
chunk_size
,
chunk_size
=
_attn_
chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
)
...
...
@@ -375,12 +379,17 @@ class TemplatePairStack(nn.Module):
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
tuned_
chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
t
,),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
chunk_size
)
for
b
in
blocks
]
blocks
=
[
partial
(
b
,
chunk_size
=
chunk_size
,
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
t
,
=
checkpoint_blocks
(
blocks
=
blocks
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment