Commit b40fab25 authored by Gustaf Ahdritz

Add offloading to structure module

parent 7d9e6830
@@ -19,7 +19,7 @@ from operator import mul

 import torch
 import torch.nn as nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Sequence

 from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
@@ -229,9 +229,11 @@ class InvariantPointAttention(nn.Module):
     def forward(
         self,
         s: torch.Tensor,
-        z: torch.Tensor,
+        z: Optional[torch.Tensor],
         r: Rigid,
         mask: torch.Tensor,
+        _offload_inference: bool = False,
+        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -247,6 +249,10 @@ class InvariantPointAttention(nn.Module):
             [*, N_res, C_s] single representation update
         """
         inplace_safe = not (self.training or torch.is_grad_enabled())

+        if(_offload_inference and inplace_safe):
+            z = _z_reference_list
+        else:
+            z = [z]

         #######################################
         # Generate scalar and point activations
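The hunk above wraps `z` in a single-element list. The reason is Python's calling semantics: rebinding a parameter inside a function is invisible to the caller, but mutating a slot of a shared list is visible to every holder of that list, which is what lets this module move the pair tensor off the GPU and back while the structure module keeps a live handle to it. A minimal sketch of the distinction (hypothetical names, not OpenFold API):

```python
import torch

def rebind(z):
    # Rebinds the local name only; the caller's tensor is untouched.
    z = z.cpu()

def mutate_slot(z_ref):
    # Mutates the shared list slot; every holder of z_ref sees the change.
    z_ref[0] = z_ref[0].cpu()

t = torch.zeros(8, 8)
ref = [t]
rebind(t)         # no observable effect outside the function
mutate_slot(ref)  # ref[0] is now a CPU tensor, visible to the caller
print(ref[0].device)  # cpu
```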
@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
         # Compute attention scores
         ##########################
         # [*, N_res, N_res, H]
-        b = self.linear_b(z)
+        b = self.linear_b(z[0])
+
+        if(_offload_inference):
+            z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
         a = torch.matmul(
@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

+        if(_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
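These two hunks bracket the memory-heavy middle of the attention computation: `z` is pushed to the CPU immediately after its last read (`linear_b`) and pulled back onto the device only when `o_pair` needs it again. A runnable sketch of the round-trip pattern, with stand-in math and assumed names:

```python
import torch

def attention_like(z_ref, offload=True):
    # Last read of the pair tensor before the memory-heavy region.
    b = z_ref[0].sum(dim=-1)
    if offload:
        # Park the pair tensor on the host while it is unused.
        z_ref[0] = z_ref[0].cpu()
    # Stand-in for the peak-memory attention computation.
    a = torch.softmax(b, dim=-1)
    if offload:
        # Restore it just before the pair-weighted output is formed.
        z_ref[0] = z_ref[0].to(a.device)
    return a @ z_ref[0]

z_ref = [torch.randn(16, 16)]
print(attention_like(z_ref).shape)  # torch.Size([16])
```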
@@ -402,9 +414,9 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            ).to(dtype=z.dtype)
+            ).to(dtype=z[0].dtype)
         )

         return s
@@ -604,17 +616,19 @@ class StructureModule(nn.Module):
     def forward(
         self,
-        s,
-        z,
+        evoformer_output_dict,
         aatype,
         mask=None,
+        _offload_inference=False,
     ):
         """
         Args:
-            s:
-                [*, N_res, C_s] single representation
-            z:
-                [*, N_res, N_res, C_z] pair representation
+            evoformer_output_dict:
+                Dictionary containing:
+                    "single":
+                        [*, N_res, C_s] single representation
+                    "pair":
+                        [*, N_res, N_res, C_z] pair representation
             aatype:
                 [*, N_res] amino acid indices
             mask:
@@ -626,11 +640,19 @@ class StructureModule(nn.Module):
             # [*, N]
             mask = s.new_ones(s.shape[:-1])

+        s = evoformer_output_dict["single"]
+
         # [*, N, C_s]
         s = self.layer_norm_s(s)

         # [*, N, N, C_z]
-        z = self.layer_norm_z(z)
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+        z_reference_list = None
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None

         # [*, N, C_s]
         s_initial = s
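Here the caller's raw pair tensor is moved to the CPU up front, and only the layer-normed copy stays on the GPU, held through `z_reference_list`; setting `z = None` drops the loose name so the list slot is the single handle that IPA mutates. A back-of-envelope estimate of what one copy costs (assumed sizes, fp32):

```python
# One fp32 pair-representation copy at N_res = 2000 with C_z = 128
# (AlphaFold-scale assumptions) is roughly 2 GB, so keeping only one
# copy on the GPU during the structure module is a meaningful saving.
n_res, c_z, fp32_bytes = 2000, 128, 4
pair_gib = n_res * n_res * c_z * fp32_bytes / 2**30
print(f"~{pair_gib:.1f} GiB per pair-representation copy")
```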
@@ -647,11 +669,18 @@ class StructureModule(nn.Module):
         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, rigids, mask)
+            s = s + self.ipa(
+                s,
+                z,
+                rigids,
+                mask,
+                _offload_inference=_offload_inference,
+                _z_reference_list=z_reference_list
+            )
             s = self.ipa_dropout(s)
             s = self.layer_norm_ipa(s)
             s = self.transition(s)

             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))
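Note that the offload flags are threaded through every one of the `self.no_blocks` iterations, so the layer-normed pair tensor makes one GPU-to-CPU round trip per structure-module block: each IPA call parks it after `linear_b` and restores it before `o_pair`. The scheme trades host-device transfer time for a lower peak memory footprint during the point-attention math.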
@@ -698,6 +727,13 @@ class StructureModule(nn.Module):
             rigids = rigids.stop_rot_gradient()

+        del z, z_reference_list
+
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = (
+                evoformer_output_dict["pair"].to(s.device)
+            )
+
         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s