Commit 578541c8 authored by Gustaf Ahdritz

Improve memory efficiency of structure module

parent e9e3fbdc
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import math
from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple
@@ -36,6 +39,8 @@ from openfold.utils.tensor_utils import (
flatten_final_dims,
)
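# Compiled CUDA extension used below to perform the attention softmax in place during inference.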
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
@@ -241,6 +246,8 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe = not (self.training or torch.is_grad_enabled())
#######################################
# Generate scalar and point activations
#######################################
@@ -303,7 +310,10 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
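# Squaring in place reuses the existing [*, N_res, N_res, H, P_q, 3] buffer
# instead of allocating a second one.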
if(inplace_safe):
pt_att *= pt_att
else:
pt_att = pt_att ** 2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
@@ -313,7 +323,10 @@ class InvariantPointAttention(nn.Module):
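# Per-head point weighting w_L * w_C from the AlphaFold 2 supplement:
# sqrt(1/3) * sqrt(2 / (9 * no_qk_points)).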
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
pt_att = pt_att * head_weights
if(inplace_safe):
pt_att *= head_weights
else:
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
@@ -323,9 +336,21 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
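# In the inference-only path, pt_att and the mask are accumulated into `a` in place
# and the softmax runs through the fused CUDA kernel, avoiding extra
# [*, H, N_res, N_res] copies of the attention logits.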
if(inplace_safe):
a += pt_att
del pt_att
a += square_mask.unsqueeze(-3)
# in-place softmax
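# (the kernel views `a` as prod(a.shape[:-1]) rows of length a.shape[-1]
# and normalizes each row in place)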
attn_core_inplace_cuda.forward_(
a,
reduce(mul, a.shape[:-1]),
a.shape[-1],
)
else:
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
################
# Compute output
@@ -338,16 +363,22 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, H, 3, N_res, P_v]
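# One matmul per spatial coordinate produces the same [*, H, 3, N_res, P_v] result
# without materializing the [*, H, 3, N_res, N_res, P_v] broadcast product formed
# in the branch below.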
if(inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
for v in torch.unbind(v_pts, dim=-3)
]
o_pt = torch.stack(o_pt, dim=-3)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
@@ -620,7 +651,7 @@ class StructureModule(nn.Module):
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
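# bb_update predicts a per-residue quaternion + translation update that
# compose_q_update_vec folds into the current backbone frames.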
rigids = rigids.compose_q_update_vec(self.bb_update(s))