Commit 5e4d790d authored by zhuww

use offload inference on InvariantPointAttention

parent 75b87f63
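
This commit moves the structure module to offload inference: the Evoformer pair representation `z` is consumed once at the top of InvariantPointAttention, parked in host memory while the memory-heavy attention math runs, and restored to the GPU just before the `o_pair` projection needs it. The sketch below shows the core pattern under illustrative names (none of them FastFold API) and runs on CPU, where the device moves are no-ops; the tensor travels inside a one-element list so that moves made by the callee are visible to the caller.

```python
# Minimal sketch of the offload pattern; all names are illustrative
# stand-ins, not FastFold API.
import torch

def attention_with_offload(s: torch.Tensor, z_ref: list) -> torch.Tensor:
    # Consume z while it is still on the compute device.
    b = z_ref[0].sum(dim=-1)

    # Park z in host memory; mutating the list slot makes the move
    # visible to the caller, which holds the same list.
    z_ref[0] = z_ref[0].cpu()

    # Memory-hungry work that never touches z runs here.
    a = torch.softmax(s @ s.transpose(-1, -2) + b, dim=-1)

    # Bring z back only when it is actually needed again.
    z_ref[0] = z_ref[0].to(a.device)
    return a @ z_ref[0].mean(dim=-1, keepdim=True)

z_ref = [torch.randn(16, 16)]  # stand-in for the pair tensor
out = attention_with_offload(torch.randn(16, 8), z_ref)
```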
@@ -393,48 +393,40 @@ class AlphaFold(nn.Module):
         torch.cuda.empty_cache()
+        if no_iter == 3:
         outputs["msa"] = m[..., :n_seq, :, :]
         outputs["pair"] = z
         outputs["single"] = s
+        del z

         # Predict 3D structure
-        outputs_sm = self.structure_module(
-            s,
-            z,
+        outputs["sm"] = self.structure_module(
+            outputs,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
         )
         torch.cuda.empty_cache()
-        outputs["sm"] = outputs_sm
+        if no_iter == 3:
+            m_1_prev, z_prev, x_prev = None, None, None
         outputs["final_atom_positions"] = atom14_to_atom37(
             outputs["sm"]["positions"][-1], feats
         )
         outputs["final_atom_mask"] = feats["atom37_atom_exists"]
         outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
-        # Save embeddings for use during the next recycling iteration
+        else:
         # [*, N, C_m]
         m_1_prev = m[..., 0, :, :]
         # [*, N, N, C_z]
-        z_prev = z
+        z_prev = outputs["pair"]
         # [*, N, 3]
-        x_prev = atom14_to_atom37(
-            outputs_sm["positions"][-1], feats
-        )
-        if no_iter != 3:
-            return None, m_1_prev, z_prev, x_prev
-        else:
+        x_prev = outputs["final_atom_positions"]

         return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None
@@ -537,6 +529,7 @@ class AlphaFold(nn.Module):
             )
             if cycle_no != 3:
+                del outputs
             prevs = [m_1_prev, z_prev, x_prev]
             del m_1_prev, z_prev, x_prev
...
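
The signature change above is what makes the offload possible: `structure_module` now receives the whole `outputs` dict instead of bare `s`/`z` tensors, so it can swap `outputs["pair"]` for a CPU copy in place while the recycling code (`z_prev = outputs["pair"]`) still reads the same dict slot. A minimal sketch of that handoff, with hypothetical names:

```python
# Sketch of the dict handoff (hypothetical names): mutating a dict
# entry inside the callee is visible to the caller afterwards.
import torch

def structure_module(outputs: dict) -> None:
    # Offload the pair tensor; the caller's dict now points at the CPU copy.
    outputs["pair"] = outputs["pair"].cpu()
    # ... IPA blocks would run here against a saved device reference ...

outputs = {"pair": torch.randn(4, 4), "single": torch.randn(4, 8)}
structure_module(outputs)
print(outputs["pair"].device)  # cpu: the move is visible to the caller
```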
@@ -14,11 +14,13 @@
 # limitations under the License.

 import math
+import sys
 import torch
 import torch.nn as nn
 from typing import Any, Dict, Optional, Tuple, Union

-from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
+from fastfold.model.nn.primitives import Linear, ipa_point_weights_init_
+from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.common.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
@@ -292,6 +294,7 @@ class InvariantPointAttention(nn.Module):
         z: torch.Tensor,
         r: Union[Rigid, Rigid3Array],
         mask: torch.Tensor,
+        _offload_inference: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -380,7 +383,11 @@
         ##########################
         # Compute attention scores
         ##########################
         # [*, N_res, N_res, H]
-        b = self.linear_b(z)
+        b = self.linear_b(z[0])
+
+        if(_offload_inference):
+            assert(sys.getrefcount(z[0]) == 2)
+            z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
         a = torch.matmul(
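
The `assert(sys.getrefcount(z[0]) == 2)` guard above makes the offload safe: `getrefcount` counts one reference for the list slot and one for its own argument, so a value of 2 proves nothing else holds the tensor, and rebinding `z[0]` to the CPU copy actually releases the device memory. A CPU-only demonstration of the invariant:

```python
# Demonstration of the refcount invariant behind the assert (CPU-only).
import sys
import torch

z = [torch.zeros(8, 8)]            # the list slot holds one reference
assert sys.getrefcount(z[0]) == 2  # +1 for getrefcount's own argument

alias = z[0]                       # any extra reference would keep the
assert sys.getrefcount(z[0]) == 3  # old device tensor alive after offload
```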
@@ -508,14 +515,17 @@
         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

+        if(_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
+        del a
         torch.cuda.empty_cache()

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
-        del a
+        torch.cuda.empty_cache()

         # [*, N_res, C_s]
         if self.is_multimer:
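
The hunk above also reorders the cleanup: `z` returns to the device only when the `[*, N_res, H, C_z]` `o_pair` product needs it, and `del a` now runs immediately after `a` is last consumed so its large `[*, H, N_res, N_res]` block can be reclaimed before the next allocation. A small sketch of that free-early ordering, with illustrative shapes:

```python
# Sketch of the free-early ordering (illustrative shapes): drop the last
# reference to the big intermediate as soon as it is consumed so the
# caching allocator can reuse its block.
import torch

h, n, c = 4, 64, 8
a = torch.randn(h, n, n)                       # stand-in for attention weights
z = torch.randn(n, n, h * c)                   # stand-in for the pair tensor

o_pair = torch.matmul(a.transpose(-2, -3), z)  # consume a ...
del a                                          # ... then free it at once
if torch.cuda.is_available():
    torch.cuda.empty_cache()                   # return freed blocks to the driver
```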
@@ -526,7 +536,7 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            ).to(dtype=z.dtype)
+            ).to(dtype=z[0].dtype)
         )
         torch.cuda.empty_cache()
@@ -737,8 +747,7 @@
     def _forward_monomer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
@@ -755,6 +764,8 @@
         Returns:
             A dictionary of outputs
         """
+        s = evoformer_output_dict["single"]
+
         if mask is None:
             # [*, N]
             mask = s.new_ones(s.shape[:-1])
@@ -765,9 +776,20 @@
         # [*, N, N, C_z]
-        z = self.layer_norm_z(z)
         # inplace z
-        # z[0] = z[0].contiguous()
-        # torch.cuda.emtpy_cache()
-        # z[0] = self.layer_norm_z(z[0])
-        # z = self.layer_norm_z(z)
+        evoformer_output_dict["pair"] = evoformer_output_dict["pair"].contiguous()
+        torch.cuda.empty_cache()
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+        _offload_inference = True
+        z_reference_list = None
+        if(_offload_inference):
+            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None
+        torch.cuda.empty_cache()

         # [*, N, C_s]
         s_initial = s
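
Order matters in this setup: `layer_norm_z` runs while the pair tensor is still on the device, the raw pair is then parked on the CPU, the normalized copy goes into the one-element `z_reference_list`, and `z = None` drops the direct name so the list holds the only reference (which the assert inside IPA later re-checks). A sketch of the same sequence under hypothetical stand-ins:

```python
# Sketch of the offload setup (hypothetical stand-ins): keep the
# normalized copy in a one-element list, park the raw pair on CPU,
# and drop the direct name so the list owns the only reference.
import sys
import torch

layer_norm_z = torch.nn.LayerNorm(32)

evoformer_output_dict = {"pair": torch.randn(16, 16, 32)}
z = layer_norm_z(evoformer_output_dict["pair"])

assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()

z_reference_list = [z]  # IPA blocks read z through this list
z = None                # the list now holds the only reference
assert sys.getrefcount(z_reference_list[0]) == 2
```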
@@ -786,7 +808,8 @@
         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, rigids, mask)
+            # inplace z
+            s = s + self.ipa(s, z_reference_list, rigids, mask, _offload_inference=_offload_inference)
             s = self.ipa_dropout(s)
             torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
@@ -840,6 +863,11 @@
             if i < (self.no_blocks - 1):
                 rigids = rigids.stop_rot_gradient()

+        del z, z_reference_list
+
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
+
         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s
@@ -847,8 +875,7 @@
     def _forward_multimer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
@@ -916,8 +943,7 @@
     def forward(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ):
@@ -935,9 +961,9 @@
             A dictionary of outputs
         """
         if self.is_multimer:
-            outputs = self._forward_multimer(s, z, aatype, mask)
+            outputs = self._forward_multimer(evoformer_output_dict, aatype, mask)
         else:
-            outputs = self._forward_monomer(s, z, aatype, mask)
+            outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)

         return outputs
...