Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a17a9777
Commit
a17a9777
authored
Oct 06, 2022
by
Gustaf Ahdritz
Browse files
Fix precision bug
parent
48670cfc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
4 deletions
+8
-4
openfold/model/heads.py
openfold/model/heads.py
+3
-3
openfold/model/primitives.py
openfold/model/primitives.py
+2
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+1
-0
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+2
-0
No files found.
openfold/model/heads.py
View file @
a17a9777
...
@@ -151,7 +151,6 @@ class DistogramHead(nn.Module):
...
@@ -151,7 +151,6 @@ class DistogramHead(nn.Module):
return
logits
return
logits
def
forward
(
self
,
z
):
def
forward
(
self
,
z
):
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
if
float16_enabled
and
torch
.
is_autocast_enabled
():
if
float16_enabled
and
torch
.
is_autocast_enabled
():
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
...
@@ -159,6 +158,7 @@ class DistogramHead(nn.Module):
...
@@ -159,6 +158,7 @@ class DistogramHead(nn.Module):
else
:
else
:
return
self
.
_forward
(
z
)
return
self
.
_forward
(
z
)
class
TMScoreHead
(
nn
.
Module
):
class
TMScoreHead
(
nn
.
Module
):
"""
"""
For use in computation of TM-score, subsection 1.9.7
For use in computation of TM-score, subsection 1.9.7
...
...
openfold/model/primitives.py
View file @
a17a9777
...
@@ -480,8 +480,9 @@ class Attention(nn.Module):
...
@@ -480,8 +480,9 @@ class Attention(nn.Module):
# [*, Q, H, C_hidden]
# [*, Q, H, C_hidden]
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
if
float16_enabled
:
if
float16_enabled
and
torch
.
is_autocast_enabled
()
:
use_memory_efficient_kernel
=
False
use_memory_efficient_kernel
=
False
if
(
use_memory_efficient_kernel
):
if
(
use_memory_efficient_kernel
):
if
(
len
(
biases
)
>
2
):
if
(
len
(
biases
)
>
2
):
raise
ValueError
(
raise
ValueError
(
...
...
openfold/model/structure_module.py
View file @
a17a9777
...
@@ -324,6 +324,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -324,6 +324,7 @@ class InvariantPointAttention(nn.Module):
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
)
)
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
...
...
openfold/model/triangular_multiplicative_update.py
View file @
a17a9777
...
@@ -391,12 +391,14 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -391,12 +391,14 @@ class TriangleMultiplicativeUpdate(nn.Module):
b
=
mask
b
=
mask
b
=
b
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
self
.
linear_b_p
(
z
)
b
=
b
*
self
.
linear_b_p
(
z
)
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
float16_enabled
=
(
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
)
if
float16_enabled
and
torch
.
is_autocast_enabled
():
if
float16_enabled
and
torch
.
is_autocast_enabled
():
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
x
=
self
.
_combine_projections
(
a
.
float
(),
b
.
float
())
x
=
self
.
_combine_projections
(
a
.
float
(),
b
.
float
())
else
:
else
:
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
_combine_projections
(
a
,
b
)
del
a
,
b
del
a
,
b
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment