Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d5480cbe
Commit
d5480cbe
authored
Oct 29, 2021
by
Gustaf Ahdritz
Browse files
Fix horrible in-place operation bug
parent
2faff451
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
18 deletions
+25
-18
deepspeed_config.json
deepspeed_config.json
+1
-1
openfold/model/template.py
openfold/model/template.py
+14
-11
train_openfold.py
train_openfold.py
+10
-6
No files found.
deepspeed_config.json
View file @
d5480cbe
...
...
@@ -23,7 +23,7 @@
"opt_level"
:
"O2"
},
"zero_optimization"
:
{
"stage"
:
1
"stage"
:
2
},
"activation_checkpointing"
:
{
"partition_activations"
:
true
,
...
...
openfold/model/template.py
View file @
d5480cbe
...
...
@@ -44,7 +44,6 @@ class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def
__init__
(
self
,
c_t
,
c_z
,
c_hidden
,
no_heads
,
inf
,
**
kwargs
):
"""
Args:
...
...
@@ -85,9 +84,6 @@ class TemplatePointwiseAttention(nn.Module):
[*, N_res, N_res, C_z] pair embedding update
"""
if
template_mask
is
None
:
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# 1, but not sure enough to remove it. It's nice to have, I guess.
template_mask
=
t
.
new_ones
(
t
.
shape
[:
-
3
])
bias
=
self
.
inf
*
(
template_mask
[...,
None
,
None
,
None
,
None
,
:]
-
1
)
...
...
@@ -174,11 +170,16 @@ class TemplatePairStackBlock(nn.Module):
)
def
forward
(
self
,
z
,
mask
,
chunk_size
,
_mask_trans
=
True
):
for
templ_idx
in
range
(
z
.
shape
[
-
4
]):
# Select a single template at a time
single
=
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
single_mask
=
mask
[...,
templ_idx
:
templ_idx
+
1
,
:,
:]
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
]
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single
=
single
+
self
.
dropout_row
(
self
.
tri_att_start
(
single
,
...
...
@@ -210,7 +211,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
chunk_size
,
mask
=
single_mask
if
_mask_trans
else
None
)
z
[...,
templ_idx
:
templ_idx
+
1
,
:,
:,
:]
=
single
single_templates
[
i
]
=
single
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
...
...
@@ -219,7 +223,6 @@ class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def
__init__
(
self
,
c_t
,
...
...
train_openfold.py
View file @
d5480cbe
...
...
@@ -61,6 +61,10 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
loss
=
None
logging
.
warning
(
"loss is NaN. Skipping example..."
)
return
{
"loss"
:
loss
}
def
validation_step
(
self
,
batch
,
batch_idx
):
...
...
@@ -113,12 +117,12 @@ def main(args):
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
model_module
.
load_state_dict
(
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
#
data_module = DummyDataLoader("batch.pickle")
data_module
=
OpenFoldDataModule
(
config
=
config
.
data
,
batch_seed
=
args
.
seed
,
**
vars
(
args
)
)
data_module
=
DummyDataLoader
(
"batch.pickle"
)
#
data_module = OpenFoldDataModule(
#
config=config.data,
#
batch_seed=args.seed,
#
**vars(args)
#
)
data_module
.
prepare_data
()
data_module
.
setup
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment