OpenDAS / OpenFold
convert suspicious fp16 regions back to fp32

Commit aef97f4b, authored Sep 21, 2022 by Bei Wang
Parent: 7384e2d6
Showing 5 changed files with 46 additions and 9 deletions:
openfold/model/heads.py  +11 -3
openfold/model/outer_product_mean.py  +16 -1
openfold/model/primitives.py  +1 -0
openfold/model/structure_module.py  +12 -4
openfold/model/triangular_multiplicative_update.py  +6 -1
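Each changed file applies the same guard: when GPU autocast is active in float16, the suspicious region is recomputed outside autocast with float32 inputs. A minimal, self-contained sketch of that recurring pattern (the fp32_fallback helper name is illustrative and not part of the commit):

import torch

def fp32_fallback(fn, *tensors):
    # Same check the commit inserts in each module: bypass autocast only when
    # it is active and its GPU dtype is float16.
    float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
    if float16_enabled and torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            # Cast inputs up and run the sensitive computation in float32.
            return fn(*(t.float() for t in tensors))
    else:
        return fn(*tensors)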
openfold/model/heads.py

...
@@ -137,7 +137,7 @@ class DistogramHead(nn.Module):
         self.linear = Linear(self.c_z, self.no_bins, init="final")
 
-    def forward(self, z):  # [*, N, N, C_z]
+    def _forward(self, z):  # [*, N, N, C_z]
         """
         Args:
             z:
...
@@ -150,6 +150,14 @@ class DistogramHead(nn.Module):
         logits = logits + logits.transpose(-2, -3)
         return logits
 
+    def forward(self, z):
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)
+
 
 class TMScoreHead(nn.Module):
     """
...
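With the split into _forward plus a wrapping forward, calling the head inside an autocast region now returns float32 logits. A hypothetical usage sketch, assuming the usual DistogramHead(c_z, no_bins) constructor and placeholder sizes:

import torch
from openfold.model.heads import DistogramHead

head = DistogramHead(c_z=128, no_bins=64).cuda()  # placeholder sizes
z = torch.randn(1, 64, 64, 128, device="cuda")    # [*, N, N, C_z]

with torch.cuda.amp.autocast(dtype=torch.float16):
    logits = head(z)  # forward() leaves autocast, so logits stay in float32

print(logits.dtype)  # torch.float32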
openfold/model/outer_product_mean.py

...
@@ -93,7 +93,7 @@ class OuterProductMean(nn.Module):
         return outer
 
-    def forward(self,
+    def _forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
...
@@ -143,3 +143,18 @@ class OuterProductMean(nn.Module):
         outer = outer / norm
 
         return outer
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(m.float(), mask, chunk_size, inplace_safe)
+        else:
+            return self._forward(m, mask, chunk_size, inplace_safe)
...
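One plausible reason the outer-product mean counts as a "suspicious" fp16 region is that each output entry sums many products over the sequence dimension, which can exceed float16's finite range (about 65504) and turn into inf, while the same reduction stays finite in float32. An illustration of that failure mode (not taken from the diff; requires a CUDA device):

import torch

a = torch.full((1024, 8), 8.0, dtype=torch.float16, device="cuda")
b = torch.full((1024, 8), 8.0, dtype=torch.float16, device="cuda")

# Each output entry sums 1024 products of 64.0 = 65536, just past fp16's max.
outer_fp16 = torch.einsum("sc,sd->cd", a, b)
outer_fp32 = torch.einsum("sc,sd->cd", a.float(), b.float())

print(torch.isinf(outer_fp16).any())  # tensor(True, device='cuda:0')
print(torch.isinf(outer_fp32).any())  # tensor(False, device='cuda:0')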
openfold/model/primitives.py

...
@@ -479,6 +479,7 @@ class Attention(nn.Module):
         q, k, v = self._prep_qkv(q_x, kv_x)
         # [*, Q, H, C_hidden]
 
+        use_memory_efficient_kernel = False
         if (use_memory_efficient_kernel):
             if (len(biases) > 2):
                 raise ValueError(
...
openfold/model/structure_module.py

...
@@ -312,6 +312,14 @@ class InvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
 
         # [*, H, N_res, N_res]
-        a = torch.matmul(
-            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
-            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
-        )
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                a = torch.matmul(
+                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
+                )
+        else:
+            a = torch.matmul(
+                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
...
openfold/model/triangular_multiplicative_update.py

...
@@ -391,6 +391,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
-        x = self._combine_projections(a, b)
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self._combine_projections(a.float(), b.float())
+        else:
+            x = self._combine_projections(a, b)
 
         del a, b
         x = self.layer_norm_out(x)
...