OpenDAS / OpenFold · Commits

Commit 9082c254 (unverified)
Authored Sep 27, 2022 by Gustaf Ahdritz; committed by GitHub on Sep 27, 2022

Merge pull request #222 from beiwang2003/main

FP16 fixes

Parents: 499b9a84, 4d5fa31c
Changes: showing 5 changed files with 48 additions and 9 deletions (+48 -9).
openfold/model/heads.py                              +11 -3
openfold/model/outer_product_mean.py                 +16 -1
openfold/model/primitives.py                          +3 -0
openfold/model/structure_module.py                   +12 -4
openfold/model/triangular_multiplicative_update.py    +6 -1
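Note: most of the changes follow one pattern — when fp16 autocast is active, autocast is disabled locally and the numerically sensitive computation is re-run in float32. In heads.py and outer_product_mean.py the original forward is renamed _forward and wrapped; in structure_module.py and triangular_multiplicative_update.py only the sensitive operation is wrapped; primitives.py instead falls back from the memory-efficient attention kernel under fp16. A minimal sketch of the wrapper variant, using a hypothetical ToyModule that is not part of OpenFold:

import torch
import torch.nn as nn


class ToyModule(nn.Module):
    # Hypothetical module illustrating the fp16-safety wrapper used in this commit.

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # The original computation, left unchanged by the refactor.
        return self.linear(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # When autocast is configured for fp16 and currently active, disable it
        # locally and run the numerically sensitive computation in float32.
        float16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
        if float16_enabled and torch.is_autocast_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(x.float())
        else:
            return self._forward(x)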
openfold/model/heads.py

@@ -137,7 +137,7 @@ class DistogramHead(nn.Module):
         self.linear = Linear(self.c_z, self.no_bins, init="final")
 
-    def forward(self, z):  # [*, N, N, C_z]
+    def _forward(self, z):  # [*, N, N, C_z]
         """
         Args:
             z:
@@ -150,6 +150,14 @@ class DistogramHead(nn.Module):
         logits = logits + logits.transpose(-2, -3)
 
         return logits
 
+    def forward(self, z):
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)
+
 
 class TMScoreHead(nn.Module):
     """
...
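Note: as a quick sanity check of the new DistogramHead behavior (a sketch only; it assumes a CUDA device, that openfold is importable, and uses illustrative sizes c_z=128 and no_bins=64 that this commit does not mandate):

import torch
from openfold.model.heads import DistogramHead

if torch.cuda.is_available():
    head = DistogramHead(c_z=128, no_bins=64).cuda()
    z = torch.randn(1, 16, 16, 128, device="cuda")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = head(z)

    # The new wrapper disables autocast and upcasts z, so the distogram
    # logits should come back in float32 even inside the fp16 region.
    print(logits.dtype)  # expected: torch.float32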
openfold/model/outer_product_mean.py

@@ -93,7 +93,7 @@ class OuterProductMean(nn.Module):
         return outer
 
-    def forward(self,
+    def _forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
@@ -143,3 +143,18 @@ class OuterProductMean(nn.Module):
         outer = outer / norm
 
         return outer
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(m.float(), mask, chunk_size, inplace_safe)
+        else:
+            return self._forward(m, mask, chunk_size, inplace_safe)
openfold/model/primitives.py

@@ -479,6 +479,9 @@ class Attention(nn.Module):
         q, k, v = self._prep_qkv(q_x, kv_x)
 
         # [*, Q, H, C_hidden]
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled:
+            use_memory_efficient_kernel = False
 
         if(use_memory_efficient_kernel):
             if(len(biases) > 2):
                 raise ValueError(
...
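Note: unlike the other changes, this guard checks only the configured autocast GPU dtype and does not also test torch.is_autocast_enabled(), so the memory-efficient kernel is skipped whenever the autocast dtype is set to fp16. A small sketch (assuming PyTorch >= 1.10, where these accessors exist) of what the two calls report:

import torch

# get_autocast_gpu_dtype() reports the *configured* GPU autocast dtype
# (float16 unless changed via torch.set_autocast_gpu_dtype), while
# is_autocast_enabled() reports whether an autocast region is active.
print(torch.get_autocast_gpu_dtype())  # typically torch.float16
print(torch.is_autocast_enabled())     # False outside an autocast block

if torch.cuda.is_available():
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        print(torch.get_autocast_gpu_dtype())  # torch.bfloat16 inside this region
        print(torch.is_autocast_enabled())     # True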
openfold/model/structure_module.py

@@ -312,6 +312,14 @@ class InvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
 
         # [*, H, N_res, N_res]
-        a = torch.matmul(
-            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
-            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                a = torch.matmul(
+                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
+                )
+        else:
+            a = torch.matmul(
+                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
...
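Note: here the float32 fallback is scoped to just the q·k product inside InvariantPointAttention rather than to a whole _forward. A standalone sketch of the same idea (plain tensors, not OpenFold's actual q and k), showing the dtype difference under fp16 autocast:

import torch

if torch.cuda.is_available():
    q = torch.randn(1, 8, 64, 16, device="cuda")
    k = torch.randn(1, 8, 64, 16, device="cuda")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        # Under fp16 autocast, a plain matmul is executed in float16.
        a_fp16 = torch.matmul(q, k.transpose(-1, -2))

        # The commit's approach: disable autocast locally and upcast the
        # operands, so this particular product is computed in float32.
        with torch.cuda.amp.autocast(enabled=False):
            a_fp32 = torch.matmul(q.float(), k.transpose(-1, -2).float())

    print(a_fp16.dtype, a_fp32.dtype)  # torch.float16 torch.float32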
openfold/model/triangular_multiplicative_update.py

@@ -391,6 +391,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
-        x = self._combine_projections(a, b)
+        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
+        if float16_enabled and torch.is_autocast_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self._combine_projections(a.float(), b.float())
+        else:
+            x = self._combine_projections(a, b)
 
         del a, b
 
         x = self.layer_norm_out(x)
...