Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b40fab25
Commit
b40fab25
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Add offloading to structure module
parent
7d9e6830
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
15 deletions
+51
-15
openfold/model/structure_module.py
openfold/model/structure_module.py
+51
-15
No files found.
openfold/model/structure_module.py
View file @
b40fab25
...
@@ -19,7 +19,7 @@ from operator import mul
...
@@ -19,7 +19,7 @@ from operator import mul
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
,
Sequence
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
...
@@ -229,9 +229,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -229,9 +229,11 @@ class InvariantPointAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
s
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
]
,
r
:
Rigid
,
r
:
Rigid
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -247,6 +249,10 @@ class InvariantPointAttention(nn.Module):
...
@@ -247,6 +249,10 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation update
[*, N_res, C_s] single representation update
"""
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
z
=
[
z
]
#######################################
#######################################
# Generate scalar and point activations
# Generate scalar and point activations
...
@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
...
@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
# Compute attention scores
# Compute attention scores
##########################
##########################
# [*, N_res, N_res, H]
# [*, N_res, N_res, H]
b
=
self
.
linear_b
(
z
)
b
=
self
.
linear_b
(
z
[
0
])
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
cpu
()
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
...
@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v, 3]
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
device
)
# [*, N_res, H, C_z]
# [*, N_res, H, C_z]
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
.
to
(
dtype
=
a
.
dtype
))
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
[
0
]
.
to
(
dtype
=
a
.
dtype
))
# [*, N_res, H * C_z]
# [*, N_res, H * C_z]
o_pair
=
flatten_final_dims
(
o_pair
,
2
)
o_pair
=
flatten_final_dims
(
o_pair
,
2
)
...
@@ -402,9 +414,9 @@ class InvariantPointAttention(nn.Module):
...
@@ -402,9 +414,9 @@ class InvariantPointAttention(nn.Module):
s
=
self
.
linear_out
(
s
=
self
.
linear_out
(
torch
.
cat
(
torch
.
cat
(
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
),
o_pt_norm
,
o_pair
),
dim
=-
1
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
),
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
.
dtype
)
).
to
(
dtype
=
z
[
0
]
.
dtype
)
)
)
return
s
return
s
...
@@ -604,17 +616,19 @@ class StructureModule(nn.Module):
...
@@ -604,17 +616,19 @@ class StructureModule(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
s
,
evoformer_output_dict
,
z
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
_offload_inference
=
False
,
):
):
"""
"""
Args:
Args:
s:
evoformer_output_dict:
[*, N_res, C_s] single representation
Dictionary containing:
z:
"single":
[*, N_res, N_res, C_z] pair representation
[*, N_res, C_s] single representation
"pair":
[*, N_res, N_res, C_z] pair representation
aatype:
aatype:
[*, N_res] amino acid indices
[*, N_res] amino acid indices
mask:
mask:
...
@@ -626,11 +640,19 @@ class StructureModule(nn.Module):
...
@@ -626,11 +640,19 @@ class StructureModule(nn.Module):
# [*, N]
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
s
=
evoformer_output_dict
[
"single"
]
# [*, N, C_s]
# [*, N, C_s]
s
=
self
.
layer_norm_s
(
s
)
s
=
self
.
layer_norm_s
(
s
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
layer_norm_z
(
z
)
z
=
self
.
layer_norm_z
(
evoformer_output_dict
[
"pair"
])
z_reference_list
=
None
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
evoformer_output_dict
[
"pair"
].
cpu
()
z_reference_list
=
[
z
]
z
=
None
# [*, N, C_s]
# [*, N, C_s]
s_initial
=
s
s_initial
=
s
...
@@ -647,11 +669,18 @@ class StructureModule(nn.Module):
...
@@ -647,11 +669,18 @@ class StructureModule(nn.Module):
outputs
=
[]
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
# [*, N]
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
...
@@ -698,6 +727,13 @@ class StructureModule(nn.Module):
...
@@ -698,6 +727,13 @@ class StructureModule(nn.Module):
rigids
=
rigids
.
stop_rot_gradient
()
rigids
=
rigids
.
stop_rot_gradient
()
del
z
,
z_reference_list
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
(
evoformer_output_dict
[
"pair"
].
to
(
s
.
device
)
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
outputs
[
"single"
]
=
s
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment