OpenDAS / FastFold

Commit 5e4d790d (parent 75b87f63), authored Dec 13, 2022 by zhuww:
use offload inference on InvariantPointAttention
Showing 2 changed files with 65 additions and 46 deletions:

fastfold/model/hub/alphafold.py          +23 / -30
fastfold/model/nn/structure_module.py    +42 / -16
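The change applies a CPU-offload pattern to InvariantPointAttention (IPA) during inference: the pair representation, which is quadratic in sequence length, is parked on the CPU while the attention runs the computation that does not need it, and is copied back only for the projections that read it. The tensor is carried in a single-element list so the callee can swap the storage in place and the caller observes the move. The sketch below is a minimal, self-contained illustration of that pattern, not FastFold code; the function and tensor names are invented for the example.

import sys
import torch

def pair_bias_with_offload(z_list, offload=False):
    # Toy stand-in for how InvariantPointAttention uses the pair tensor.
    # z_list is a single-element list so this function can move the tensor
    # between devices in place; the caller sees the same list afterwards.

    # First use of the pair tensor (stand-in for self.linear_b(z[0])).
    b = z_list[0].sum(dim=-1)

    if offload:
        # getrefcount counts the list slot plus its own argument, so 2 means
        # nothing else is keeping the (GPU) tensor alive.
        assert sys.getrefcount(z_list[0]) == 2
        z_list[0] = z_list[0].cpu()
        torch.cuda.empty_cache()

    # ... point-attention math that does not touch the pair tensor ...

    if offload:
        # Bring it back only when it is needed again.
        z_list[0] = z_list[0].to(b.device)

    return b

z_ref = [torch.randn(16, 16, 8)]  # pair activation, held in a mutable list
_ = pair_bias_with_offload(z_ref, offload=True)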
fastfold/model/hub/alphafold.py
...
...
@@ -393,47 +393,39 @@ class AlphaFold(nn.Module):
         torch.cuda.empty_cache()
         if no_iter == 3:
-            outputs["msa"] = m[..., :n_seq, :, :]
-            outputs["pair"] = z
-            outputs["single"] = s
+            outputs["msa"] = m[..., :n_seq, :, :]
+            outputs["pair"] = z
+            outputs["single"] = s
+            del z

         # Predict 3D structure
-        outputs_sm = self.structure_module(
-            s,
-            z,
+        outputs["sm"] = self.structure_module(
+            outputs,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
         )
         torch.cuda.empty_cache()

         if no_iter == 3:
             m_1_prev, z_prev, x_prev = None, None, None
-            outputs["sm"] = outputs_sm
-            outputs["final_atom_positions"] = atom14_to_atom37(outputs["sm"]["positions"][-1], feats)
-            outputs["final_atom_mask"] = feats["atom37_atom_exists"]
-            outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
         else:
             # Save embeddings for use during the next recycling iteration
+            outputs["final_atom_positions"] = atom14_to_atom37(outputs["sm"]["positions"][-1], feats)
+            outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+            outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
-            # [*, N, C_m]
-            m_1_prev = m[..., 0, :, :]
+            # [*, N, C_m]
+            m_1_prev = m[..., 0, :, :]
-            # [*, N, N, C_z]
-            z_prev = outputs["pair"]
+            # [*, N, N, C_z]
+            z_prev = z
-            # [*, N, 3]
-            x_prev = outputs["final_atom_positions"]
-        return outputs, m_1_prev, z_prev, x_prev
+            # [*, N, 3]
+            x_prev = atom14_to_atom37(outputs_sm["positions"][-1], feats)
+        if no_iter != 3:
+            return None, m_1_prev, z_prev, x_prev
+        else:
+            return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None
...
...
@@ -537,6 +529,7 @@ class AlphaFold(nn.Module):
             )

+            if cycle_no != 3:
                 del outputs

             prevs = [m_1_prev, z_prev, x_prev]
             del m_1_prev, z_prev, x_prev
...
...
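On the alphafold.py side, the key change is that the structure module now receives the whole evoformer output dict instead of separate s and z tensors, so the pair entry can be replaced (normalized, moved to the CPU, and later restored) inside the callee while the caller keeps a handle to it. A toy stand-in for that hand-off, with invented 8-channel activations and everything but the dict passing omitted:

import torch

class ToyStructureModule(torch.nn.Module):
    # Stand-in that takes the whole evoformer dict instead of (s, z).
    def __init__(self):
        super().__init__()
        self.layer_norm_z = torch.nn.LayerNorm(8)

    def forward(self, evoformer_output_dict):
        s = evoformer_output_dict["single"]
        # Normalize the pair activation, then park the original on the CPU.
        # The normalized copy lives only inside a single-element list so the
        # attention code can offload it as well (see structure_module.py).
        z = self.layer_norm_z(evoformer_output_dict["pair"])
        evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
        z_reference_list = [z]
        z = None
        # ... IPA blocks read (and may offload) z_reference_list[0] ...
        return {"single": s}

outputs = {"single": torch.randn(4, 8), "pair": torch.randn(4, 4, 8)}
outputs["sm"] = ToyStructureModule()(outputs)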
fastfold/model/nn/structure_module.py
...
...
@@ -14,11 +14,13 @@
 # limitations under the License.

 import math
+import sys

 import torch
 import torch.nn as nn
 from typing import Any, Dict, Optional, Tuple, Union

-from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
+from fastfold.model.nn.primitives import Linear, ipa_point_weights_init_
+from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.common.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
...
...
@@ -292,6 +294,7 @@ class InvariantPointAttention(nn.Module):
         z: torch.Tensor,
         r: Union[Rigid, Rigid3Array],
         mask: torch.Tensor,
+        _offload_inference: bool = False,
     ) -> torch.Tensor:
         """
         Args:
...
...
@@ -380,7 +383,11 @@ class InvariantPointAttention(nn.Module):
         # Compute attention scores
         ##########################
         # [*, N_res, N_res, H]
-        b = self.linear_b(z)
+        b = self.linear_b(z[0])
+
+        if (_offload_inference):
+            assert (sys.getrefcount(z[0]) == 2)
+            z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
         a = torch.matmul(
...
...
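The sys.getrefcount(z[0]) == 2 assertion above is the guard for the offload: getrefcount counts the list slot plus the temporary reference created by the call itself, so a value of 2 means no other name still points at the tensor and replacing z[0] with a CPU copy will actually let the GPU memory be freed. A small demonstration in plain Python (not FastFold code):

import sys
import torch

z = [torch.zeros(2, 2)]
assert sys.getrefcount(z[0]) == 2   # the list slot + getrefcount's own argument

alias = z[0]                        # a lingering alias would keep a GPU copy alive
assert sys.getrefcount(z[0]) == 3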
@@ -508,14 +515,17 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

+        if (_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
-        del a
-        torch.cuda.empty_cache()

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
+        del a
+        torch.cuda.empty_cache()

         # [*, N_res, C_s]
         if self.is_multimer:
...
...
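Inside the attention body, the pair tensor is needed again only for the output pair projection, so the hunk above copies it back to the working device right before that matmul; the offload window covers the point-attention math in between. A device-agnostic sketch of the round trip, with made-up shapes (H=4 heads, N=16 residues, C_z=8), not FastFold code:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

z = [torch.randn(16, 16, 8, device=device)]   # pair activation, [N, N, C_z]
a = torch.randn(4, 16, 16, device=device)     # attention weights, [H, N, N]

z[0] = z[0].cpu()                              # offloaded during the heavy math
# ... point attention, value aggregation, etc. ...
z[0] = z[0].to(a.device)                       # restored just before the pair projection

# [N, H, C_z]: weight the pair activation by the attention pattern
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))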
@@ -526,7 +536,7 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            ).to(dtype=z.dtype)
+            ).to(dtype=z[0].dtype)
         )
         torch.cuda.empty_cache()
...
...
@@ -737,8 +747,7 @@ class StructureModule(nn.Module):
     def _forward_monomer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
...
...
@@ -755,6 +764,8 @@ class StructureModule(nn.Module):
         Returns:
             A dictionary of outputs
         """
+        s = evoformer_output_dict["single"]
+
         if mask is None:
             # [*, N]
             mask = s.new_ones(s.shape[:-1])
...
...
@@ -765,9 +776,20 @@ class StructureModule(nn.Module):
         # [*, N, N, C_z]
-        z = self.layer_norm_z(z)
+        # inplace z
+        # z[0] = z[0].contiguous()
+        # torch.cuda.emtpy_cache()
+        # z[0] = self.layer_norm_z(z[0])
+        evoformer_output_dict["pair"] = evoformer_output_dict["pair"].contiguous()
+        torch.cuda.empty_cache()
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+        # z = self.layer_norm_z(z)
+
+        _offload_inference = True
+        z_reference_list = None
+        if (_offload_inference):
+            assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None
+        torch.cuda.empty_cache()

         # [*, N, C_s]
         s_initial = s
...
...
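Setting z = None immediately after building z_reference_list is what makes the later offload legal: if the local name z still pointed at the tensor, the sys.getrefcount(z[0]) == 2 assertion inside InvariantPointAttention would fail, and the GPU copy could not be released anyway. A short illustration with invented names:

import sys
import torch

z = torch.randn(4, 4, 8)
z_reference_list = [z]
assert sys.getrefcount(z_reference_list[0]) == 3   # name z + list slot + call argument

z = None
assert sys.getrefcount(z_reference_list[0]) == 2   # only the list slot remains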
@@ -786,7 +808,8 @@ class StructureModule(nn.Module):
         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, rigids, mask)
+            # inplace z
+            s = s + self.ipa(s, z_reference_list, rigids, mask, _offload_inference=_offload_inference)
             s = self.ipa_dropout(s)
             torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
...
...
@@ -839,6 +862,11 @@ class StructureModule(nn.Module):
             if i < (self.no_blocks - 1):
                 rigids = rigids.stop_rot_gradient()

+        del z, z_reference_list
+
+        if (_offload_inference):
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
+
         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s
...
...
@@ -847,8 +875,7 @@ class StructureModule(nn.Module):
     def _forward_multimer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
...
...
@@ -916,8 +943,7 @@ class StructureModule(nn.Module):
     def forward(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ):
...
...
@@ -935,9 +961,9 @@ class StructureModule(nn.Module):
             A dictionary of outputs
         """
         if self.is_multimer:
-            outputs = self._forward_multimer(s, z, aatype, mask)
+            outputs = self._forward_multimer(evoformer_output_dict, aatype, mask)
         else:
-            outputs = self._forward_monomer(s, z, aatype, mask)
+            outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)

         return outputs
...
...