Commit 5e4d790d authored by zhuww

use offload inference on InvariantPointAttention

parent 75b87f63
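This commit ports an offload-inference pattern onto FastFold's structure module: the pair representation z, the largest activation at O(N²) memory in sequence length, is parked on the CPU while the InvariantPointAttention (IPA) blocks run, and is pulled back to the GPU only for the operations that actually read it. A minimal, runnable sketch of the pattern, with illustrative names rather than the FastFold API:

import sys
import torch

def ipa_like_block(s, z_ref, offload=True):
    # z_ref is a one-element list so the callee can swap the tensor in
    # place and the caller observes the move.
    b = z_ref[0].sum(dim=-1)                    # early, cheap read of z
    if offload:
        assert sys.getrefcount(z_ref[0]) == 2   # list slot + getrefcount arg
        z_ref[0] = z_ref[0].cpu()               # park z; its GPU memory is freed
    # ... the expensive point-attention math runs here without z ...
    if offload:
        z_ref[0] = z_ref[0].to(s.device)        # bring z back for its last use
    return s + b.mean()

s = torch.randn(16, 8)
z_ref = [torch.randn(16, 16, 4)]  # wrapped so callees can replace the storage
s = ipa_like_block(s, z_ref)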
@@ -393,47 +393,39 @@ class AlphaFold(nn.Module):
                 torch.cuda.empty_cache()
 
-                if no_iter == 3:
-                    outputs["msa"] = m[..., :n_seq, :, :]
-                    outputs["pair"] = z
-                    outputs["single"] = s
+                outputs["msa"] = m[..., :n_seq, :, :]
+                outputs["pair"] = z
+                outputs["single"] = s
+                del z
 
                 # Predict 3D structure
-                outputs_sm = self.structure_module(
-                    s,
-                    z,
+                outputs["sm"] = self.structure_module(
+                    outputs,
                     feats["aatype"],
                     mask=feats["seq_mask"].to(dtype=s.dtype),
                 )
                 torch.cuda.empty_cache()
 
                 if no_iter == 3:
                     m_1_prev, z_prev, x_prev = None, None, None
-                    outputs["sm"] = outputs_sm
                     outputs["final_atom_positions"] = atom14_to_atom37(
                         outputs["sm"]["positions"][-1], feats
                     )
                     outputs["final_atom_mask"] = feats["atom37_atom_exists"]
                     outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
                 else:
                     # Save embeddings for use during the next recycling iteration
+                    outputs["final_atom_positions"] = atom14_to_atom37(
+                        outputs["sm"]["positions"][-1], feats
+                    )
+                    outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+                    outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
                     # [*, N, C_m]
                     m_1_prev = m[..., 0, :, :]
                     # [*, N, N, C_z]
-                    z_prev = z
+                    z_prev = outputs["pair"]
                     # [*, N, 3]
-                    x_prev = atom14_to_atom37(
-                        outputs_sm["positions"][-1], feats
-                    )
+                    x_prev = outputs["final_atom_positions"]
 
-                if no_iter != 3:
-                    return None, m_1_prev, z_prev, x_prev
-                else:
-                    return outputs, m_1_prev, z_prev, x_prev
+                return outputs, m_1_prev, z_prev, x_prev
 
     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None
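A note on the call-convention change above: the structure module now receives the whole evoformer output dict instead of unpacked `s`/`z` tensors. That lets the callee replace `outputs["pair"]` with a CPU copy while it runs and restore it before the caller reads `z_prev` back out of the dict. A sketch of that contract (hypothetical function, not the FastFold signature):

import torch

def structure_module_like(outputs):
    device = outputs["pair"].device
    # The callee owns the dict entry: replacing it drops the only GPU
    # reference, so the allocator can reuse that memory mid-call.
    outputs["pair"] = outputs["pair"].cpu()
    # ... IPA blocks run here with the pair representation offloaded ...
    outputs["pair"] = outputs["pair"].to(device)   # restore for the caller

outputs = {"pair": torch.randn(8, 8, 4)}
structure_module_like(outputs)
z_prev = outputs["pair"]   # caller reads the restored tensor, as the diff does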
@@ -537,6 +529,7 @@ class AlphaFold(nn.Module):
                 )
 
+                if cycle_no != 3:
+                    del outputs
                 prevs = [m_1_prev, z_prev, x_prev]
                 del m_1_prev, z_prev, x_prev
@@ -14,11 +14,13 @@
 # limitations under the License.
 
 import math
+import sys
 import torch
 import torch.nn as nn
 from typing import Any, Dict, Optional, Tuple, Union
 
-from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
+from fastfold.model.nn.primitives import Linear, ipa_point_weights_init_
+from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.common.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
@@ -292,6 +294,7 @@ class InvariantPointAttention(nn.Module):
         z: torch.Tensor,
         r: Union[Rigid, Rigid3Array],
         mask: torch.Tensor,
+        _offload_inference: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -380,7 +383,11 @@ class InvariantPointAttention(nn.Module):
         # Compute attention scores
         ##########################
         # [*, N_res, N_res, H]
-        b = self.linear_b(z)
+        b = self.linear_b(z[0])
+
+        if(_offload_inference):
+            assert(sys.getrefcount(z[0]) == 2)
+            z[0] = z[0].cpu()
 
         # [*, H, N_res, N_res]
         a = torch.matmul(
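The `sys.getrefcount(z[0]) == 2` assertion above is what makes the offload effective rather than cosmetic: moving a tensor to the CPU only releases GPU memory if the list slot held the last reference. `getrefcount` reports one extra count for its own argument, so a result of 2 means exactly one real owner. A small demonstration of that arithmetic, nothing FastFold-specific:

import sys
import torch

z = [torch.zeros(4, 4)]
print(sys.getrefcount(z[0]))        # 2: the list slot + getrefcount's argument

alias = z[0]                        # a second real owner appears
print(sys.getrefcount(z[0]))        # 3: offloading now would not free memory,
                                    #    since `alias` still pins the tensor

del alias
assert sys.getrefcount(z[0]) == 2   # sole ownership again; safe to offload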
@@ -508,14 +515,17 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
 
+        if(_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
+        del a
+        torch.cuda.empty_cache()
 
         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
-        del a
-        torch.cuda.empty_cache()
 
         # [*, N_res, C_s]
         if self.is_multimer:
@@ -526,7 +536,7 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            ).to(dtype=z.dtype)
+            ).to(dtype=z[0].dtype)
         )
         torch.cuda.empty_cache()
@@ -737,8 +747,7 @@ class StructureModule(nn.Module):
     def _forward_monomer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
@@ -755,6 +764,8 @@ class StructureModule(nn.Module):
         Returns:
             A dictionary of outputs
         """
+        s = evoformer_output_dict["single"]
+
         if mask is None:
             # [*, N]
             mask = s.new_ones(s.shape[:-1])
@@ -765,9 +776,20 @@ class StructureModule(nn.Module):
         # [*, N, N, C_z]
-        z = self.layer_norm_z(z)
+        # inplace z
+        # z[0] = z[0].contiguous()
+        # torch.cuda.empty_cache()
+        # z[0] = self.layer_norm_z(z[0])
+        evoformer_output_dict["pair"] = evoformer_output_dict["pair"].contiguous()
+        torch.cuda.empty_cache()
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+        # z = self.layer_norm_z(z)
+
+        _offload_inference = True
+        z_reference_list = None
+        if(_offload_inference):
+            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None
+            torch.cuda.empty_cache()
 
         # [*, N, C_s]
         s_initial = s
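Why the dance of `z_reference_list = [z]` followed by `z = None` above: if `z` were passed to the IPA as a plain tensor argument, the caller's local name would keep a live GPU reference for the whole loop, so the `.cpu()` move inside the IPA could never free anything. Wrapping the tensor in a one-element list gives the structure module and the IPA one shared mutable slot, and clearing the local name leaves that slot as the sole owner. A compact sketch of the difference (hypothetical helpers):

import torch

def offload_by_value(z):
    z = z.cpu()   # rebinds a local name only; the caller still holds the
                  # original tensor, so no GPU memory is released

def offload_by_slot(z_ref):
    z_ref[0] = z_ref[0].cpu()   # replaces the shared slot; the old tensor
                                # loses its last owner and can be freed

z = torch.randn(8, 8)
z_ref = [z]
z = None                  # drop the local alias, as the diff does with z = None
offload_by_slot(z_ref)    # the list slot is the only reference, so the move sticks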
@@ -786,7 +808,8 @@ class StructureModule(nn.Module):
         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, rigids, mask)
+            # inplace z
+            s = s + self.ipa(s, z_reference_list, rigids, mask, _offload_inference=_offload_inference)
             s = self.ipa_dropout(s)
             torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
@@ -839,6 +862,11 @@ class StructureModule(nn.Module):
             if i < (self.no_blocks - 1):
                 rigids = rigids.stop_rot_gradient()
 
+        del z, z_reference_list
+
+        if(_offload_inference):
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
+
         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s
@@ -847,8 +875,7 @@ class StructureModule(nn.Module):
     def _forward_multimer(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:
@@ -916,8 +943,7 @@ class StructureModule(nn.Module):
     def forward(
         self,
-        s: torch.Tensor,
-        z: torch.Tensor,
+        evoformer_output_dict,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ):
@@ -935,9 +961,9 @@ class StructureModule(nn.Module):
             A dictionary of outputs
         """
         if self.is_multimer:
-            outputs = self._forward_multimer(s, z, aatype, mask)
+            outputs = self._forward_multimer(evoformer_output_dict, aatype, mask)
         else:
-            outputs = self._forward_monomer(s, z, aatype, mask)
+            outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)
 
         return outputs