Commit 578541c8 authored by Gustaf Ahdritz

Improve memory efficiency of structure module

parent e9e3fbdc
@@ -12,8 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import reduce
+import importlib
 import math
+from operator import mul
 import torch
 import torch.nn as nn
 from typing import Optional, Tuple
@@ -36,6 +39,8 @@ from openfold.utils.tensor_utils import (
     flatten_final_dims,
 )

+attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
+

 class AngleResnetBlock(nn.Module):
     def __init__(self, c_hidden):
@@ -241,6 +246,8 @@ class InvariantPointAttention(nn.Module):
         Returns:
             [*, N_res, C_s] single representation update
         """
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
         #######################################
         # Generate scalar and point activations
         #######################################
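The new inplace_safe flag gates every in-place optimization in this commit: tensors are only mutated in place when the module is not training and no autograd graph is being recorded, because overwriting a tensor that autograd has saved for the backward pass would raise an error or corrupt gradients. A minimal standalone sketch of the pattern, with illustrative names and shapes that are not part of the commit:

import torch

def square(x: torch.Tensor, inplace_safe: bool) -> torch.Tensor:
    if inplace_safe:
        x *= x       # reuses x's storage, no extra allocation
    else:
        x = x ** 2   # allocates a new tensor; the input stays intact for backward
    return x

# Training / grad-enabled path: the out-of-place version is required because
# autograd saves the input of pow() for its backward pass.
x = torch.randn(8, requires_grad=True)
square(x, inplace_safe=False).sum().backward()

# Inference path: with no graph being recorded, mutating the buffer is safe
# and avoids a second full-size allocation.
with torch.no_grad():
    square(torch.randn(8), inplace_safe=True)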
@@ -303,7 +310,10 @@
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)

-        pt_att = pt_att ** 2
+        if(inplace_safe):
+            pt_att *= pt_att
+        else:
+            pt_att = pt_att ** 2

         # [*, N_res, N_res, H, P_q]
         pt_att = sum(torch.unbind(pt_att, dim=-1))
@@ -313,7 +323,10 @@
         head_weights = head_weights * math.sqrt(
             1.0 / (3 * (self.no_qk_points * 9.0 / 2))
         )
-        pt_att = pt_att * head_weights
+        if(inplace_safe):
+            pt_att *= head_weights
+        else:
+            pt_att = pt_att * head_weights

         # [*, N_res, N_res, H]
         pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
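The per-head weighting gets the same treatment: the in-place multiply is valid because pt_att already has the full output shape and head_weights simply broadcasts against it. A toy illustration with arbitrary sizes, not the module's real configuration:

import torch

N_res, H, P_q = 4, 2, 3                      # arbitrary toy sizes
pt_att = torch.randn(N_res, N_res, H, P_q)
head_weights = torch.rand(1, 1, H, 1)        # one positive weight per head

# Out-of-place: allocates a second [N_res, N_res, H, P_q] tensor.
reference = pt_att * head_weights

# In-place: head_weights broadcasts into pt_att's existing storage, which only
# works because the left operand already has the larger (output) shape.
pt_att *= head_weights

assert torch.allclose(pt_att, reference)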
@@ -323,9 +336,21 @@
         # [*, H, N_res, N_res]
         pt_att = permute_final_dims(pt_att, (2, 0, 1))

-        a = a + pt_att
-        a = a + square_mask.unsqueeze(-3)
-        a = self.softmax(a)
+        if(inplace_safe):
+            a += pt_att
+            del pt_att
+            a += square_mask.unsqueeze(-3)
+            # in-place softmax
+            attn_core_inplace_cuda.forward_(
+                a,
+                reduce(mul, a.shape[:-1]),
+                a.shape[-1],
+            )
+        else:
+            a = a + pt_att
+            a = a + square_mask.unsqueeze(-3)
+            a = self.softmax(a)

         ################
         # Compute output
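In the inplace_safe branch, the custom CUDA kernel normalizes the attention logits over the last dimension directly in a's storage, so no second [*, H, N_res, N_res] tensor is allocated for the softmax; reduce(mul, a.shape[:-1]) is just the number of rows the kernel has to process. A pure-PyTorch stand-in with the same effect, written only to illustrate the call's semantics rather than the extension's actual implementation:

import torch
from functools import reduce
from operator import mul

def softmax_inplace_reference(a: torch.Tensor) -> None:
    # Same row/column bookkeeping the kernel call above receives.
    n_rows, n_cols = reduce(mul, a.shape[:-1]), a.shape[-1]
    flat = a.view(n_rows, n_cols)                     # view: writes land in a's storage
    flat -= flat.max(dim=-1, keepdim=True).values     # shift for numerical stability
    flat.exp_()
    flat /= flat.sum(dim=-1, keepdim=True)

a = torch.randn(2, 4, 8, 8)   # toy [*, H, N_res, N_res] logits
softmax_inplace_reference(a)
assert torch.allclose(a.sum(dim=-1), torch.ones(2, 4, 8))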
@@ -338,16 +363,22 @@
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)

-        # As DeepMind explains, this manual matmul ensures that the operation
-        # happens in float32.
         # [*, H, 3, N_res, P_v]
-        o_pt = torch.sum(
-            (
-                a[..., None, :, :, None]
-                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-            ),
-            dim=-2,
-        )
+        if(inplace_safe):
+            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
+            o_pt = [
+                torch.matmul(a, v.to(a.dtype))
+                for v in torch.unbind(v_pts, dim=-3)
+            ]
+            o_pt = torch.stack(o_pt, dim=-3)
+        else:
+            o_pt = torch.sum(
+                (
+                    a[..., None, :, :, None]
+                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+                ),
+                dim=-2,
+            )

         # [*, N_res, H, P_v, 3]
         o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
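The rewritten branch computes the same point outputs as the broadcast-and-sum expression kept in the else branch, but never materializes the [*, H, 3, N_res, N_res, P_v] intermediate; instead it performs one [N_res, N_res] x [N_res, P_v] matmul per spatial coordinate. A small equivalence check with toy shapes (sizes are arbitrary, and v_pts is assumed to already be permuted to [H, 3, N_res, P_v]):

import torch

H, N_res, P_v = 2, 5, 4
a = torch.rand(H, N_res, N_res)          # attention weights per head
v_pts = torch.randn(H, 3, N_res, P_v)    # value points, already permuted

# Broadcast path: builds an [H, 3, N_res, N_res, P_v] intermediate before the sum.
o_big = torch.sum(a[..., None, :, :, None] * v_pts[..., None, :, :], dim=-2)

# Per-coordinate matmul path: the largest temporary is only [H, N_res, P_v].
o_small = torch.stack(
    [torch.matmul(a, v) for v in torch.unbind(v_pts, dim=-3)],
    dim=-3,
)

assert torch.allclose(o_big, o_small, atol=1e-5)

The v.to(a.dtype) cast in the committed branch presumably keeps each matmul in the precision of the attention weights, in the same spirit as the removed comment about performing this operation in float32.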
@@ -620,7 +651,7 @@ class StructureModule(nn.Module):
             s = self.ipa_dropout(s)
             s = self.layer_norm_ipa(s)
             s = self.transition(s)

             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))