Commit 578541c8 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Improve memory efficiency of structure module

parent e9e3fbdc
......@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import math
from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple
......@@ -36,6 +39,8 @@ from openfold.utils.tensor_utils import (
flatten_final_dims,
)
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
......@@ -241,6 +246,8 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe = not (self.training or torch.is_grad_enabled())
#######################################
# Generate scalar and point activations
#######################################
......@@ -303,6 +310,9 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
if(inplace_safe):
pt_att *= pt_att
else:
pt_att = pt_att ** 2
# [*, N_res, N_res, H, P_q]
......@@ -313,6 +323,9 @@ class InvariantPointAttention(nn.Module):
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
if(inplace_safe):
pt_att *= head_weights
else:
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
......@@ -323,6 +336,18 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
if(inplace_safe):
a += pt_att
del pt_att
a += square_mask.unsqueeze(-3)
# in-place softmax
attn_core_inplace_cuda.forward_(
a,
reduce(mul, a.shape[:-1]),
a.shape[-1],
)
else:
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
......@@ -338,9 +363,15 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
if(inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
for v in torch.unbind(v_pts, dim=-3)
]
o_pt = torch.stack(o_pt, dim=-3)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment