OpenDAS / OpenFold / Commits / 1df4991d

Commit 1df4991d
Authored Jan 13, 2022 by Gustaf Ahdritz
Parent: 02fc4376

    DeepSpeed + PL bfloat16 working

Changes: 28 in the commit overall; this page (1 of 2) shows 8 changed files with 64 additions and 44 deletions (+64 / -44).
Files shown on this page:

    openfold/model/msa.py                                 +7   -7
    openfold/model/pair_transition.py                     +2   -2
    openfold/model/primitives.py                          +21  -11
    openfold/model/structure_module.py                    +13  -11
    openfold/model/template.py                            +4   -5
    openfold/model/triangular_attention.py                +3   -3
    openfold/model/triangular_multiplicative_update.py    +3   -3
    openfold/utils/rigid_utils.py                         +11  -2
openfold/model/msa.py

```diff
@@ -149,14 +149,14 @@ class MSAAttention(nn.Module):
         def _get_qkv(m, z):
             m, mask_bias, z = self._prep_inputs(m, z, mask)
             q, k, v = self.mha._prep_qkv(m, m)
-            return q, k, v, mask_bias, z
+            return m, q, k, v, mask_bias, z
 
         checkpoint_fn = get_checkpoint_fn()
-        if(checkpoint):
-            q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+        if(torch.is_grad_enabled() and checkpoint):
+            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
         else:
-            q, k, v, mask_bias, z = _get_qkv(m, z)
+            m, q, k, v, mask_bias, z = _get_qkv(m, z)
 
         o = _attention_chunked_trainable(
             query=q,
@@ -168,7 +168,7 @@ class MSAAttention(nn.Module):
             checkpoint=checkpoint,
         )
 
-        if(checkpoint):
+        if(torch.is_grad_enabled() and checkpoint):
             # Storing an additional m here is far from ideal
             m = checkpoint_fn(self.mha._wrap_up, o, m)
         else:
```
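Both hunks gate activation checkpointing on `torch.is_grad_enabled()` in addition to the `checkpoint` flag, so the recompute-in-backward machinery is skipped whenever no backward pass can happen (e.g. validation under `torch.no_grad()`), and `_get_qkv` now also returns `m` so the caller keeps the prepped MSA tensor after the checkpointed region. A minimal sketch of the gating pattern using `torch.utils.checkpoint` directly; the `expensive` function is a hypothetical stand-in, not code from this commit:

```python
import torch
from torch.utils.checkpoint import checkpoint as checkpoint_fn

def expensive(x):
    # stand-in for a heavy sub-block (e.g. the attention input projections)
    return x * 2 + 1

def run(x, use_checkpoint=True):
    # Only checkpoint when gradients are actually being tracked; otherwise
    # fall back to a plain call (checkpointing without grad is wasted work
    # and triggers warnings in recent PyTorch versions).
    if torch.is_grad_enabled() and use_checkpoint:
        return checkpoint_fn(expensive, x)
    return expensive(x)

x = torch.randn(4, 8, requires_grad=True)
y_train = run(x)              # checkpointed forward
with torch.no_grad():
    y_eval = run(x)           # plain forward during evaluation
```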
openfold/model/pair_transition.py

```diff
@@ -17,7 +17,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import chunk_layer
@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
         self.c_z = c_z
         self.n = n
 
-        self.layer_norm = nn.LayerNorm(self.c_z)
+        self.layer_norm = LayerNorm(self.c_z)
         self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
```
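This is the same swap that recurs throughout the commit: `nn.LayerNorm` is replaced by OpenFold's own `LayerNorm` from `openfold.model.primitives`. Judging by the call sites, it is a drop-in replacement that takes the normalized channel size; a small usage sketch, assuming OpenFold (and its DeepSpeed dependency) is importable:

```python
import torch
from openfold.model.primitives import LayerNorm  # bfloat16-aware variant (see primitives.py below)

c_z = 128
layer_norm = LayerNorm(c_z)        # was: nn.LayerNorm(c_z)

z = torch.randn(2, 16, 16, c_z)
out = layer_norm(z)                # same call pattern as nn.LayerNorm
```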
openfold/model/primitives.py

```diff
@@ -179,7 +179,7 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         d = x.dtype
-        if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
@@ -189,27 +189,34 @@ class LayerNorm(nn.Module):
                     self.eps
                 )
-        elif(d == torch.bfloat16):
-            raise NotImplementedError
         else:
             out = nn.functional.layer_norm(
                 x,
                 self.c_in,
                 self.weight,
                 self.bias,
                 self.eps,
             )
 
         return out
 
-def softmax(t, dim=-1):
+@torch.jit.ignore
+def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
         Softmax, but without automatic casting to fp32 when the input is of
         type bfloat16
     """
     d = t.dtype
-    if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+    if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
-    elif(d == torch.bfloat16):
-        raise NotImplementedError
     else:
         s = torch.nn.functional.softmax(t, dim=dim)
 
     return s
 
-def _attention(query, key, value, biases):
+#@torch.jit.script
+def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
     # [*, H, Q, C_hidden]
     query = permute_final_dims(query, (1, 0, 2))
@@ -225,7 +232,7 @@ def _attention(query, key, value, biases):
     for b in biases:
         a += b
 
-    a = softmax(a, dim=-1)
+    a = softmax(a, -1)
 
     # [*, H, Q, C_hidden]
     a = torch.matmul(a, value)
@@ -354,7 +361,9 @@ class Attention(nn.Module):
     def _prep_qkv(self,
         q_x: torch.Tensor,
         kv_x: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor
+    ]:
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(kv_x)
@@ -375,6 +384,7 @@ class Attention(nn.Module):
     ) -> torch.Tensor:
         if(self.linear_g is not None):
             g = self.sigmoid(self.linear_g(q_x))
             # [*, Q, H, C_hidden]
             g = g.view(g.shape[:-1] + (self.no_heads, -1))
             o = o * g
```
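The common thread in `primitives.py` is keeping bfloat16 tensors in bfloat16: when the input dtype is bfloat16 and DeepSpeed is not initialized, `layer_norm` and `softmax` run with CUDA autocast disabled instead of being promoted to fp32, and the old `NotImplementedError` for the bfloat16-under-DeepSpeed case is dropped, which is what lets "DeepSpeed + PL bfloat16" work. The new `@torch.jit.ignore` and commented-out `@torch.jit.script` decorators suggest TorchScript was being tried around these wrappers. A standalone sketch of the autocast pattern outside the OpenFold classes; the weight/bias casts in the bfloat16 branch are an assumption, since the diff truncates that branch:

```python
import torch
import torch.nn as nn

def bf16_layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    # With autocast disabled, a bfloat16 input stays bfloat16 instead of
    # being promoted to fp32 by the autocast policy.
    if x.dtype is torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            return nn.functional.layer_norm(
                x,
                normalized_shape,
                weight.to(dtype=x.dtype),  # assumed cast, keeps dtypes consistent
                bias.to(dtype=x.dtype),
                eps,
            )
    return nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)

def bf16_softmax(t, dim=-1):
    # Same idea for softmax: no silent fp32 promotion for bfloat16 inputs.
    if t.dtype is torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            return torch.nn.functional.softmax(t, dim=dim)
    return torch.nn.functional.softmax(t, dim=dim)
```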
openfold/model/structure_module.py

```diff
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 
 from typing import Optional, Tuple
 
-from openfold.model.primitives import Linear, ipa_point_weights_init_
+from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
-        a = a * math.sqrt(1.0 / (3 * self.c_hidden))
-        a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+        a *= math.sqrt(1.0 / (3 * self.c_hidden))
+        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
 
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
         # Compute output
         ################
 
         # [*, N_res, H, C_hidden]
-        o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3)
+        o = torch.matmul(
+            a, v.transpose(-2, -3).to(dtype=a.dtype)
+        ).transpose(-2, -3)
 
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)
@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
 
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z)
+        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
 
         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            )
+            ).to(dtype=z.dtype)
         )
 
         return s
@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
             self.layers.append(l)
 
         self.dropout = nn.Dropout(self.dropout_rate)
-        self.layer_norm = nn.LayerNorm(self.c)
+        self.layer_norm = LayerNorm(self.c)
 
     def forward(self, s):
         for l in self.layers:
@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
         self.atom_mask = None
         self.lit_positions = None
 
-        self.layer_norm_s = nn.LayerNorm(self.c_s)
-        self.layer_norm_z = nn.LayerNorm(self.c_z)
+        self.layer_norm_s = LayerNorm(self.c_s)
+        self.layer_norm_z = LayerNorm(self.c_z)
 
         self.linear_in = Linear(self.c_s, self.c_s)
@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
         )
 
         self.ipa_dropout = nn.Dropout(self.dropout_rate)
-        self.layer_norm_ipa = nn.LayerNorm(self.c_s)
+        self.layer_norm_ipa = LayerNorm(self.c_s)
 
         self.transition = StructureModuleTransition(
             self.c_s,
```
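Besides switching to the bfloat16-aware `LayerNorm` and scaling the attention logits in place (presumably to avoid allocating another copy of the large `a` tensor), the `InvariantPointAttention` changes are explicit dtype alignment: under bfloat16 autocast the attention weights `a` can end up in a different dtype than `v` or `z`, and `torch.matmul` refuses mixed dtypes, hence the `.to(dtype=a.dtype)` casts and the final `.to(dtype=z.dtype)` on the concatenated output. A small illustration of the failure mode and the fix, with synthetic shapes rather than the module's real tensors:

```python
import torch

# attention weights produced in bfloat16, values still in fp32
a = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)  # [*, H, N_res, N_res]
v = torch.randn(2, 8, 4, 16)                       # [*, N_res, H, C_hidden], fp32

# torch.matmul(a, v.transpose(-2, -3)) would raise a dtype-mismatch error;
# casting the second operand to a.dtype makes the product well-defined:
o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
print(o.dtype, o.shape)  # torch.bfloat16 torch.Size([2, 8, 4, 16])
```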
openfold/model/template.py

```diff
@@ -19,7 +19,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.model.dropout import (
     DropoutRowwise,
     DropoutColumnwise,
@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": z,
-            "k_x": t,
-            "v_x": t,
+            "kv_x": t,
             "biases": biases,
         }
         return chunk_layer(
@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
         if chunk_size is not None:
             z = self._chunk(z, t, biases, chunk_size)
         else:
-            z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
+            z = self.mha(q_x=z, kv_x=t, biases=biases)
 
         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)
@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
             )
             self.blocks.append(block)
 
-        self.layer_norm = nn.LayerNorm(c_t)
+        self.layer_norm = LayerNorm(c_t)
 
     def forward(
         self,
```
openfold/model/triangular_attention.py

```diff
@@ -20,7 +20,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
         self.starting = starting
         self.inf = inf
 
-        self.layer_norm = nn.LayerNorm(self.c_in)
+        self.layer_norm = LayerNorm(self.c_in)
 
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
@@ -116,7 +116,7 @@ class TriangleAttention(nn.Module):
         if chunk_size is not None:
             x = self._chunk(x, biases, chunk_size)
         else:
-            x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
+            x = self.mha(q_x=x, kv_x=x, biases=biases)
 
         if not self.starting:
             x = x.transpose(-2, -3)
```
openfold/model/triangular_multiplicative_update.py

```diff
@@ -19,7 +19,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import permute_final_dims
@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
         self.linear_g = Linear(self.c_z, self.c_z, init="gating")
         self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
 
-        self.layer_norm_in = nn.LayerNorm(self.c_z)
-        self.layer_norm_out = nn.LayerNorm(self.c_hidden)
+        self.layer_norm_in = LayerNorm(self.c_z)
+        self.layer_norm_out = LayerNorm(self.c_hidden)
 
         self.sigmoid = nn.Sigmoid()
```
openfold/utils/rigid_utils.py

```diff
@@ -26,7 +26,7 @@ def rot_matmul(
 ) -> torch.Tensor:
     """
         Performs matrix multiplication of two rotation matrix tensors. Written
-        out by hand to avoid transfer to low-precision tensor cores.
+        out by hand to avoid AMP downcasting.
 
     Args:
         a: [*, 3, 3] left multiplicand
@@ -86,7 +86,7 @@ def rot_vec_mul(
 ) -> torch.Tensor:
     """
         Applies a rotation to a vector. Written out by hand to avoid transfer
-        to low-precision tensor cores.
+        to avoid AMP downcasting.
 
     Args:
         r: [*, 3, 3] rotation matrices
@@ -323,6 +323,12 @@ class Rotation:
                 "Incorrectly shaped rotation matrix or quaternion"
             )
 
+        # Force full-precision
+        if(quats is not None):
+            quats = quats.to(dtype=torch.float32)
+        if(rot_mats is not None):
+            rot_mats = rot_mats.to(dtype=torch.float32)
+
         if(quats is not None and normalize_quats):
             quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
@@ -857,6 +863,9 @@ class Rigid:
            (rots.device != trans.device)):
             raise ValueError("Rots and trans incompatible")
 
+        # Force full precision. Happens to the rotations automatically.
+        trans = trans.to(dtype=torch.float32)
+
         self._rots = rots
         self._trans = trans
```
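`rigid_utils.py` takes the opposite tack from the rest of the commit: instead of keeping bfloat16, quaternions, rotation matrices, and translations are pinned to float32, since the hand-written rotation arithmetic (which the updated docstrings now describe as avoiding AMP downcasting) is numerically sensitive. A minimal sketch of that pattern as a free function, not the actual `Rotation`/`Rigid` constructors:

```python
import torch

def force_full_precision(quats=None, rot_mats=None, trans=None, normalize_quats=True):
    # Geometric quantities are kept in float32 even if the model produced
    # them under bfloat16 autocast.
    if quats is not None:
        quats = quats.to(dtype=torch.float32)
        if normalize_quats:
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
    if rot_mats is not None:
        rot_mats = rot_mats.to(dtype=torch.float32)
    if trans is not None:
        trans = trans.to(dtype=torch.float32)
    return quats, rot_mats, trans

quats, _, trans = force_full_precision(
    quats=torch.randn(10, 4, dtype=torch.bfloat16),
    trans=torch.randn(10, 3, dtype=torch.bfloat16),
)
print(quats.dtype, trans.dtype)  # torch.float32 torch.float32
```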