OpenDAS / OpenFold · Commits

Commit aef97f4b
authored Sep 21, 2022 by Bei Wang

convert suspicious fp16 regions back to fp32

parent 7384e2d6

Showing 5 changed files with 46 additions and 9 deletions (+46, -9)
openfold/model/heads.py  +11 -3
openfold/model/outer_product_mean.py  +16 -1
openfold/model/primitives.py  +1 -0
openfold/model/structure_module.py  +12 -4
openfold/model/triangular_multiplicative_update.py  +6 -1
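Every file in this commit applies the same recipe: detect that GPU autocast is running in float16, step out of the autocast region, and redo the numerically fragile computation in float32. Below is a minimal, self-contained sketch of that recipe; the Fp32SafeHead module, its shapes, and its layer sizes are hypothetical and only illustrate the pattern used in the diffs that follow.

import torch
import torch.nn as nn


class Fp32SafeHead(nn.Module):
    """Hypothetical head whose projection is forced to run in fp32
    even when the surrounding model runs under fp16 autocast."""

    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        self.linear = nn.Linear(c_in, c_out)

    def _forward(self, z: torch.Tensor) -> torch.Tensor:
        # The actual computation; the symmetrized logits can overflow in fp16.
        logits = self.linear(z)
        return logits + logits.transpose(-2, -3)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        float16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
        if float16_enabled and torch.is_autocast_enabled():
            # Leave the fp16 autocast region and recompute in full precision.
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(z.float())
        return self._forward(z)

Splitting forward into a thin dispatcher plus _forward keeps the fp32 and default paths identical except for the dtype handling, which is also how the heads.py and outer_product_mean.py changes below are structured.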
openfold/model/heads.py

@@ -137,7 +137,7 @@ class DistogramHead(nn.Module):
         self.linear = Linear(self.c_z, self.no_bins, init="final")
 
-    def forward(self, z):  # [*, N, N, C_z]
+    def _forward(self, z):  # [*, N, N, C_z]
         """
         Args:
             z:
@@ -149,8 +149,16 @@ class DistogramHead(nn.Module):
         logits = self.linear(z)
         logits = logits + logits.transpose(-2, -3)
         return logits
 
+    def forward(self, z):
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)
+
 
 class TMScoreHead(nn.Module):
     """
     For use in computation of TM-score, subsection 1.9.7
openfold/model/outer_product_mean.py

@@ -93,7 +93,7 @@ class OuterProductMean(nn.Module):
         return outer
 
-    def forward(self,
+    def _forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
@@ -143,3 +143,18 @@ class OuterProductMean(nn.Module):
         outer = outer / norm
 
         return outer
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(m.float(), mask, chunk_size, inplace_safe)
+        else:
+            return self._forward(m, mask, chunk_size, inplace_safe)
openfold/model/primitives.py

@@ -479,6 +479,7 @@ class Attention(nn.Module):
         q, k, v = self._prep_qkv(q_x, kv_x)
 
         # [*, Q, H, C_hidden]
+        use_memory_efficient_kernel = False
         if (use_memory_efficient_kernel):
             if (len(biases) > 2):
                 raise ValueError(
openfold/model/structure_module.py

@@ -312,10 +312,18 @@ class InvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
 
         # [*, H, N_res, N_res]
-        a = torch.matmul(
-            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
-            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
-        )
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                a = torch.matmul(
+                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
+                )
+        else:
+            a = torch.matmul(
+                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
+            )
 
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
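The matmul above is exactly the kind of region the commit message calls suspicious: values that are harmless in fp32 can overflow the fp16 maximum of roughly 65504 once autocast stores the product in half precision. A small self-contained demonstration of that failure mode (the shapes and the value 30.0 are made up for illustration; this is not OpenFold code):

import torch

if torch.cuda.is_available():
    # Moderately large activations: fine in fp32, but their dot product
    # (30 * 30 * 512 = 460800) exceeds the fp16 maximum of ~65504.
    x = torch.full((64, 512), 30.0, device="cuda")
    w = torch.full((512, 64), 30.0, device="cuda")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        y_fp16 = x @ w  # result stored as fp16 -> overflows to inf
        with torch.cuda.amp.autocast(enabled=False):
            y_fp32 = x.float() @ w.float()  # same product, kept in fp32

    print(y_fp16.dtype, torch.isinf(y_fp16).any().item())  # torch.float16 True
    print(y_fp32.dtype, torch.isinf(y_fp32).any().item())  # torch.float32 False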
openfold/model/triangular_multiplicative_update.py

@@ -391,7 +391,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
-        x = self._combine_projections(a, b)
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self._combine_projections(a.float(), b.float())
+        else:
+            x = self._combine_projections(a, b)
         del a, b
         x = self.layer_norm_out(x)
         x = self.linear_z(x)