Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
578541c8
"pcdet/ops/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "19068b52a182c04694305c4542fede9dddc4d527"
Commit
578541c8
authored
Jun 13, 2022
by
Gustaf Ahdritz
Browse files
Improve memory efficiency of structure module
parent
e9e3fbdc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
17 deletions
+48
-17
openfold/model/structure_module.py
openfold/model/structure_module.py
+48
-17
No files found.
openfold/model/structure_module.py
View file @
578541c8
...
...
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
reduce
import
importlib
import
math
from
operator
import
mul
import
torch
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
...
...
@@ -36,6 +39,8 @@ from openfold.utils.tensor_utils import (
flatten_final_dims
,
)
attn_core_inplace_cuda
=
importlib
.
import_module
(
"attn_core_inplace_cuda"
)
class
AngleResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
c_hidden
):
...
...
@@ -241,6 +246,8 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
#######################################
# Generate scalar and point activations
#######################################
...
...
@@ -303,7 +310,10 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
pt_att
**
2
if
(
inplace_safe
):
pt_att
*=
pt_att
else
:
pt_att
=
pt_att
**
2
# [*, N_res, N_res, H, P_q]
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
...
...
@@ -313,7 +323,10 @@ class InvariantPointAttention(nn.Module):
head_weights
=
head_weights
*
math
.
sqrt
(
1.0
/
(
3
*
(
self
.
no_qk_points
*
9.0
/
2
))
)
pt_att
=
pt_att
*
head_weights
if
(
inplace_safe
):
pt_att
*=
head_weights
else
:
pt_att
=
pt_att
*
head_weights
# [*, N_res, N_res, H]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
...
...
@@ -323,9 +336,21 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att
=
permute_final_dims
(
pt_att
,
(
2
,
0
,
1
))
a
=
a
+
pt_att
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
a
=
self
.
softmax
(
a
)
if
(
inplace_safe
):
a
+=
pt_att
del
pt_att
a
+=
square_mask
.
unsqueeze
(
-
3
)
# in-place softmax
attn_core_inplace_cuda
.
forward_
(
a
,
reduce
(
mul
,
a
.
shape
[:
-
1
]),
a
.
shape
[
-
1
],
)
else
:
a
=
a
+
pt_att
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
a
=
self
.
softmax
(
a
)
################
# Compute output
...
...
@@ -338,16 +363,22 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o
=
flatten_final_dims
(
o
,
2
)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt
=
torch
.
sum
(
(
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
)
# [*, H, 3, N_res, P_v]
if
(
inplace_safe
):
v_pts
=
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))
o_pt
=
[
torch
.
matmul
(
a
,
v
.
to
(
a
.
dtype
))
for
v
in
torch
.
unbind
(
v_pts
,
dim
=-
3
)
]
o_pt
=
torch
.
stack
(
o_pt
,
dim
=-
3
)
else
:
o_pt
=
torch
.
sum
(
(
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
)
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
...
...
@@ -620,7 +651,7 @@ class StructureModule(nn.Module):
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment