Commit b40fab25 authored by Gustaf Ahdritz

Add offloading to structure module

parent 7d9e6830
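In outline, the change lets inference park the Evoformer pair representation on the CPU during the stretches of the structure module that never read it. A minimal sketch of the pattern, with toy names rather than the OpenFold API:

import torch

def forward_with_offload(z_ref, offload=True):
    # z_ref is a one-element list: reassigning z_ref[0] is visible to the
    # caller, so the GPU copy of the big tensor is genuinely released.
    b = z_ref[0].sum(dim=-1)                  # last read of z for a while
    if offload:
        z_ref[0] = z_ref[0].cpu()             # park z on the host
    a = torch.softmax(b, dim=-1)              # memory-heavy work without z
    if offload:
        z_ref[0] = z_ref[0].to(a.device)      # restore z for the final read
    return a * z_ref[0].sum(dim=-1)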
@@ -19,7 +19,7 @@ from operator import mul
 import torch
 import torch.nn as nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Sequence

 from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
@@ -229,9 +229,11 @@ class InvariantPointAttention(nn.Module):
     def forward(
         self,
         s: torch.Tensor,
-        z: torch.Tensor,
+        z: Optional[torch.Tensor],
         r: Rigid,
         mask: torch.Tensor,
+        _offload_inference: bool = False,
+        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -247,6 +249,10 @@ class InvariantPointAttention(nn.Module):
             [*, N_res, C_s] single representation update
         """
         inplace_safe = not (self.training or torch.is_grad_enabled())
+        if(_offload_inference and inplace_safe):
+            z = _z_reference_list
+        else:
+            z = [z]

         #######################################
         # Generate scalar and point activations
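The list wrapper is the load-bearing detail: rebinding a plain tensor argument would only change the callee's local name, while assigning through `z[0]` swaps the tensor inside a container the caller also holds. A tiny illustration (toy tensor, hypothetical):

import torch

t = torch.ones(3)          # stand-in for the [*, N_res, N_res, C_z] pair tensor
ref = [t]                  # the shared container handed in as _z_reference_list
t = None                   # the caller drops its direct name
ref[0] = ref[0].to("cpu")  # with no other references left, the old copy is freed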
@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
         # Compute attention scores
         ##########################
         # [*, N_res, N_res, H]
-        b = self.linear_b(z)
+        b = self.linear_b(z[0])
+
+        if(_offload_inference):
+            z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
         a = torch.matmul(
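Placement matters here: `linear_b` is the last consumer of the pair representation before the attention weights are formed, so sending `z` to the host right after computing `b` keeps what is typically the largest single activation off the GPU across exactly the softmax and point-attention stretch where memory peaks.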
@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

+        if(_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
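The round trip costs two host-device copies of the pair tensor per IPA call. For the long sequences this flag targets, that transfer is usually cheap next to the memory it frees; note also that the tensor is restored to `o_pt.device` rather than a cached device, so `z` lands wherever the surrounding computation actually lives.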
@@ -402,9 +414,9 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            ).to(dtype=z.dtype)
+            ).to(dtype=z[0].dtype)
         )

         return s
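The calling convention, sketched under the assumption that `ipa`, `s`, `z`, `rigids`, and `mask` are set up as elsewhere in this module: the pair tensor travels through the reference list, `None` fills the positional `z` slot, and gradients must be disabled so that `inplace_safe` holds inside the module.

with torch.no_grad():                      # keeps inplace_safe True inside IPA
    z_reference_list = [z]                 # IPA's only live handle on z
    z = None                               # drop the local reference
    s = s + ipa(
        s,
        None,                              # z arrives via the list instead
        rigids,
        mask,
        _offload_inference=True,
        _z_reference_list=z_reference_list,
    )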
@@ -604,17 +616,19 @@ class StructureModule(nn.Module):
     def forward(
         self,
-        s,
-        z,
+        evoformer_output_dict,
         aatype,
         mask=None,
+        _offload_inference=False,
     ):
         """
         Args:
-            s:
-                [*, N_res, C_s] single representation
-            z:
-                [*, N_res, N_res, C_z] pair representation
+            evoformer_output_dict:
+                Dictionary containing:
+                    "single":
+                        [*, N_res, C_s] single representation
+                    "pair":
+                        [*, N_res, N_res, C_z] pair representation
             aatype:
                 [*, N_res] amino acid indices
             mask:
@@ -626,11 +640,19 @@ class StructureModule(nn.Module):
+        s = evoformer_output_dict["single"]
+
         if mask is None:
             # [*, N]
             mask = s.new_ones(s.shape[:-1])

         # [*, N, C_s]
         s = self.layer_norm_s(s)

         # [*, N, N, C_z]
-        z = self.layer_norm_z(z)
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+        z_reference_list = None
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None

         # [*, N, C_s]
         s_initial = s
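Note the hand-off: the raw pair tensor in `evoformer_output_dict` is parked on the CPU immediately, while the layer-normed copy travels to IPA inside `z_reference_list`; rebinding `z = None` drops the module's own strong reference so that, once IPA executes `z[0] = z[0].cpu()`, nothing pins the CUDA copy alive.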
@@ -647,11 +669,18 @@ class StructureModule(nn.Module):
         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, rigids, mask)
+            s = s + self.ipa(
+                s,
+                z,
+                rigids,
+                mask,
+                _offload_inference=_offload_inference,
+                _z_reference_list=z_reference_list,
+            )
             s = self.ipa_dropout(s)
             s = self.layer_norm_ipa(s)
             s = self.transition(s)

             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))
@@ -698,6 +727,13 @@ class StructureModule(nn.Module):
             rigids = rigids.stop_rot_gradient()

+        del z, z_reference_list
+
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = (
+                evoformer_output_dict["pair"].to(s.device)
+            )
+
         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s
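Taken together, a hedged end-to-end sketch of the inference path this commit enables, assuming `structure_module`, `evoformer_output_dict`, `aatype`, and `mask` are already constructed:

import torch

with torch.no_grad():                       # inference only; grads would defeat offloading
    outputs = structure_module(
        evoformer_output_dict,              # {"single": [*, N, C_s], "pair": [*, N, N, C_z], ...}
        aatype,                             # [*, N] residue type indices
        mask=mask,
        _offload_inference=True,            # pair rep lives on the CPU between its two uses
    )
    # evoformer_output_dict["pair"] is back on s.device when forward returns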