Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
53bb9c10
Commit
53bb9c10
authored
Jan 10, 2022
by
Gustaf Ahdritz
Browse files
Fix mask casting
parent
a8601529
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
32 deletions
+30
-32
openfold/model/model.py
openfold/model/model.py
+30
-32
No files found.
openfold/model/model.py
View file @
53bb9c10
...
...
@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_z]
t
=
self
.
template_pair_stack
(
template_embeds
[
"pair"
],
pair_mask
.
unsqueeze
(
-
3
),
pair_mask
.
unsqueeze
(
-
3
)
.
to
(
dtype
=
z
.
dtype
)
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
t
=
self
.
template_pointwise_att
(
t
,
z
,
template_mask
=
batch
[
"template_mask"
],
template_mask
=
batch
[
"template_mask"
]
.
to
(
dtype
=
z
.
dtype
)
,
chunk_size
=
self
.
globals
.
chunk_size
,
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
...
...
@@ -246,15 +246,13 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
template_mask
=
feats
[
"template_mask"
]
if
(
torch
.
any
(
template_mask
)):
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
z
,
pair_mask
,
pair_mask
.
to
(
dtype
=
z
.
dtype
)
,
no_batch_dims
,
)
...
...
@@ -284,9 +282,9 @@ class AlphaFold(nn.Module):
z
=
self
.
extra_msa_stack
(
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
],
msa_mask
=
feats
[
"extra_msa_mask"
]
.
to
(
dtype
=
a
.
dtype
)
,
chunk_size
=
self
.
globals
.
chunk_size
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
)
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -297,8 +295,8 @@ class AlphaFold(nn.Module):
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
)
,
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
)
,
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -312,7 +310,7 @@ class AlphaFold(nn.Module):
s
,
z
,
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
],
mask
=
feats
[
"seq_mask"
]
.
to
(
dtype
=
s
.
dtype
)
,
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"sm"
][
"positions"
][
-
1
],
feats
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment