Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
This diff is collapsed.
...@@ -22,6 +22,7 @@ from openfold.utils.loss import ( ...@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm, compute_tm,
compute_predicted_aligned_error, compute_predicted_aligned_error,
) )
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module): class AuxiliaryHeads(nn.Module):
...@@ -137,7 +138,7 @@ class DistogramHead(nn.Module): ...@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final") self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z] def _forward(self, z): # [*, N, N, C_z]
""" """
Args: Args:
z: z:
...@@ -149,6 +150,13 @@ class DistogramHead(nn.Module): ...@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
logits = self.linear(z) logits = self.linear(z)
logits = logits + logits.transpose(-2, -3) logits = logits + logits.transpose(-2, -3)
return logits return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
class TMScoreHead(nn.Module): class TMScoreHead(nn.Module):
......
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import weakref
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -34,12 +35,26 @@ from openfold.model.embedders import ( ...@@ -34,12 +35,26 @@ from openfold.model.embedders import (
) )
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
) )
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add,
dict_multimap, dict_multimap,
tensor_tree_map, tensor_tree_map,
) )
...@@ -61,55 +76,96 @@ class AlphaFold(nn.Module): ...@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__() super(AlphaFold, self).__init__()
self.globals = config.globals self.globals = config.globals
config = config.model self.config = config.model
template_config = config.template self.template_config = self.config.template
extra_msa_config = config.extra_msa self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module # Main trunk + structure module
if(self.globals.is_multimer): if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer( self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"], **self.config["input_embedder"],
) )
else: else:
self.input_embedder = InputEmbedder( self.input_embedder = InputEmbedder(
**config["input_embedder"], **self.config["input_embedder"],
) )
self.recycling_embedder = RecyclingEmbedder( self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"], **self.config["recycling_embedder"],
) )
if(self.globals.is_multimer): if (self.template_config.enabled):
self.template_embedder = TemplateEmbedderMultimer( if(self.globals.is_multimer):
template_config, self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
else:
self.template_embedder = TemplateEmbedder(
self.template_config,
)
if (self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
) )
else: self.extra_msa_stack = ExtraMSAStack(
self.template_embedder = TemplateEmbedder( **self.extra_msa_config["extra_msa_stack"],
template_config,
) )
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack( self.evoformer = EvoformerStack(
**config["evoformer_stack"], **self.config["evoformer_stack"],
) )
self.structure_module = StructureModule( self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer, is_multimer=self.globals.is_multimer,
**config["structure_module"], **self.config["structure_module"],
) )
self.aux_heads = AuxiliaryHeads( self.aux_heads = AuxiliaryHeads(
config["heads"], self.config["heads"],
) )
self.config = config def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): return template_embeds
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary # Primary output dictionary
outputs = {} outputs = {}
...@@ -125,19 +181,38 @@ class AlphaFold(nn.Module): ...@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2] n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3] n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features # Prep some features
seq_mask = feats["seq_mask"] seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :] pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"] msa_mask = feats["msa_mask"]
## Initialize the MSA and pair representations
# Initialize the MSA and pair representations if (self.globals.is_multimer):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
# m: [*, S_c, N, C_m] else:
# z: [*, N, N, C_z] # m: [*, S_c, N, C_m]
m, z = self.input_embedder(feats) # z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Initialize the recycling embeddings, if needs be # Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
...@@ -161,69 +236,58 @@ class AlphaFold(nn.Module): ...@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None feats["aatype"], x_prev, None
).to(dtype=z.dtype) ).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m] # m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z] # z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder( m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, x_prev,
inplace_safe=inplace_safe,
) )
# If the number of recycling iterations is 0, skip recycling if(self.globals.offload_inference and inplace_safe):
# altogether. We zero them this way instead of computing them m = m.to(m_1_prev_emb.device)
# conditionally to avoid leaving parameters unused, which has annoying z = z.to(z_prev.device)
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m] # [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z] # [*, N, N, C_z]
z += z_prev_emb z = add(z, z_prev_emb, inplace=inplace_safe)
# Possibly prevents memory fragmentation # Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled: if self.config.template.enabled:
template_feats = { template_feats = {
k: v for k, v in feats.items() if k.startswith("template_") k: v for k, v in feats.items() if k.startswith("template_")
} }
if(self.globals.is_multimer): template_embeds = self.embed_templates(
asym_id = feats["asym_id"] template_feats,
multichain_mask_2d = ( feats,
asym_id[..., None] == asym_id[..., None, :] z,
) pair_mask.to(dtype=z.dtype),
template_embeds = self.template_embedder( no_batch_dims,
template_feats, inplace_safe=inplace_safe,
z, )
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size
)
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if( if(
self.config.template.embed_angles or "template_single_embedding" in template_embeds
(self.globals.is_multimer and self.config.template.enabled)
): ):
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
...@@ -253,41 +317,80 @@ class AlphaFold(nn.Module): ...@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e] # [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats) extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat) a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z] # Run MSA + pair embeddings through the trunk of the network
z = self.extra_msa_stack( # m: [*, S, N, C_m]
extra_msa_feat, # z: [*, N, N, C_z]
z, # s: [*, N, C_s]
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype), if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype), pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :] outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z outputs["pair"] = z
outputs["single"] = s outputs["single"] = s
del z
# Predict 3D structure # Predict 3D structure
outputs["sm"] = self.structure_module( outputs["sm"] = self.structure_module(
s, outputs,
z,
feats["aatype"], feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype), mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
) )
outputs["final_atom_positions"] = atom14_to_atom37( outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats outputs["sm"]["positions"][-1], feats
...@@ -301,7 +404,7 @@ class AlphaFold(nn.Module): ...@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :] m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z] # [*, N, N, C_z]
z_prev = z z_prev = outputs["pair"]
# [*, N, 3] # [*, N, 3]
x_prev = outputs["final_atom_positions"] x_prev = outputs["final_atom_positions"]
...@@ -379,14 +482,13 @@ class AlphaFold(nn.Module): ...@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
""" """
# Initialize recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop # Main recycling loop
num_iters = batch["aatype"].shape[-1] num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters): for cycle_no in range(num_iters):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no] fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch) feats = tensor_tree_map(fetch_cur_batch, batch)
...@@ -395,7 +497,6 @@ class AlphaFold(nn.Module): ...@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter = cycle_no == (num_iters - 1) is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter: if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766) # Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
torch.clear_autocast_cache() torch.clear_autocast_cache()
...@@ -403,12 +504,15 @@ class AlphaFold(nn.Module): ...@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model # Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, feats,
m_1_prev, prevs,
z_prev,
x_prev,
_recycle=(num_iters > 1) _recycle=(num_iters > 1)
) )
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads # Run auxiliary heads
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
......
...@@ -26,8 +26,8 @@ from openfold.model.primitives import ( ...@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable, _attention_chunked_trainable,
) )
from openfold.utils.checkpointing import get_checkpoint_fn from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
) )
...@@ -89,21 +89,38 @@ class MSAAttention(nn.Module): ...@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def _chunk(self, def _chunk(self,
m: torch.Tensor, m: torch.Tensor,
biases: List[torch.Tensor], biases: Optional[List[torch.Tensor]],
use_memory_efficient_kernel: bool,
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
mha = partial( def fn(m, biases, flash_mask):
self.mha, m = self.layer_norm_m(m)
use_memory_efficient_kernel=use_memory_efficient_kernel return self.mha(
) q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
)
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer( return chunk_layer(
mha, fn,
{ inputs,
"q_x": m,
"kv_x": m,
"biases": biases,
},
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]) no_batch_dims=len(m.shape[:-2])
) )
...@@ -111,11 +128,9 @@ class MSAAttention(nn.Module): ...@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self, def _prep_inputs(self,
m: torch.Tensor, m: torch.Tensor,
z: Optional[torch.Tensor], z: Optional[torch.Tensor],
mask: Optional[torch.Tensor] mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: inplace_safe: bool = False,
# [*, N_seq, N_res, C_m] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1] n_seq, n_res = m.shape[-3:-1]
if mask is None: if mask is None:
# [*, N_seq, N_res] # [*, N_seq, N_res]
...@@ -131,11 +146,20 @@ class MSAAttention(nn.Module): ...@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript self.linear_z is not None # TorchScript
): ):
# [*, N_res, N_res, C_z] chunks = []
z = self.layer_norm_z(z)
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
# [*, N_res, N_res, no_heads] z = torch.cat(chunks, dim=-3)
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res] # [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
...@@ -149,6 +173,7 @@ class MSAAttention(nn.Module): ...@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor], mask: Optional[torch.Tensor],
chunk_logits: int, chunk_logits: int,
checkpoint: bool, checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
""" """
MSA attention with training-time chunking of the softmax computation. MSA attention with training-time chunking of the softmax computation.
...@@ -158,7 +183,10 @@ class MSAAttention(nn.Module): ...@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM = -4 MSA_DIM = -4
def _get_qkv(m, z): def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask) m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m) q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z return m, q, k, v, mask_bias, z
...@@ -193,6 +221,9 @@ class MSAAttention(nn.Module): ...@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None, _chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None, _checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -214,23 +245,43 @@ class MSAAttention(nn.Module): ...@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None): if(_chunk_logits is not None):
return self._chunked_msa_attn( return self._chunked_msa_attn(
m=m, z=z, mask=mask, m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks chunk_logits=_chunk_logits,
) checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
m, mask_bias, z = self._prep_inputs(m, z, mask) )
biases = [mask_bias] if(use_flash):
if(z is not None): assert z is None
biases.append(z) biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None: if chunk_size is not None:
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size) m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else: else:
m = self.layer_norm_m(m)
m = self.mha( m = self.mha(
q_x=m, q_x=m,
kv_x=m, kv_x=m,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
) )
return m return m
...@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module): ...@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module): ...@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if mask is not None: if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size) m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
...@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
mha_input = { mha_input = {
"m": m, "m": m,
"mask": mask, "mask": mask,
} }
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer( return chunk_layer(
self.global_attention, fn,
mha_input, mha_input,
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]), no_batch_dims=len(m.shape[:-2]),
...@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:] n_seq, n_res, c_in = m.shape[-3:]
...@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m) #m = self.layer_norm_m(m)
if chunk_size is not None: if chunk_size is not None:
m = self._chunk(m, mask, chunk_size) m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else: else:
m = self.global_attention(m=m, mask=mask) m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
......
...@@ -20,7 +20,8 @@ import torch ...@@ -20,7 +20,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module): class OuterProductMean(nn.Module):
...@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module): ...@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims=1, no_batch_dims=1,
) )
out.append(outer) out.append(outer)
outer = torch.stack(out, dim=0)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer return outer
def forward(self, def _forward(self,
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module): ...@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask = m.new_ones(m.shape[:-1]) mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m] # [*, N_seq, N_res, C_m]
m = self.layer_norm(m) ln = self.layer_norm(m)
# [*, N_seq, N_res, C] # [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask a = self.linear_1(ln)
b = self.linear_2(m) * mask a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3) a = a.transpose(-2, -3)
b = b.transpose(-2, -3) b = b.transpose(-2, -3)
...@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module): ...@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1] # [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask) norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm) if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
...@@ -18,7 +18,7 @@ import torch ...@@ -18,7 +18,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module): class PairTransition(nn.Module):
...@@ -46,12 +46,16 @@ class PairTransition(nn.Module): ...@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask): def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden] # [*, N_res, N_res, C_hidden]
z = self.linear_1(z) z = self.linear_1(z)
z = self.relu(z) z = self.relu(z)
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask z = self.linear_2(z)
z = z * mask
return z return z
...@@ -68,7 +72,6 @@ class PairTransition(nn.Module): ...@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]), no_batch_dims=len(z.shape[:-2]),
) )
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
...@@ -88,9 +91,6 @@ class PairTransition(nn.Module): ...@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1] # [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None: if chunk_size is not None:
z = self._chunk(z, mask, chunk_size) z = self._chunk(z, mask, chunk_size)
else: else:
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import os
import glob
import importlib

# Auto-import every sibling module in this package and re-export it, so that
# `from <package> import *` (and plain attribute access) picks up all modules
# without maintaining an explicit list.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)

# Avoid needlessly cluttering the global namespace.
# NOTE: the loop variable is only bound when __all__ is non-empty, so guard
# its deletion to avoid a NameError in an otherwise-empty package.
del _files
if __all__:
    del _name
This diff is collapsed.
import os
import glob
import importlib

# Auto-import every sibling module in this package and re-export it, so that
# `from <package> import *` (and plain attribute access) picks up all modules
# without maintaining an explicit list.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)

# Avoid needlessly cluttering the global namespace.
# NOTE: the loop variable is only bound when __all__ is non-empty, so guard
# its deletion to avoid a NameError in an otherwise-empty package.
del _files
if __all__:
    del _name
This diff is collapsed.
...@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure). ...@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io import io
import pdbfixer import pdbfixer
from simtk.openmm import app try:
from simtk.openmm.app import element # openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info): def fix_pdb(pdbfile, alterations_info):
......
...@@ -87,4 +87,7 @@ class AmberRelaxation(object): ...@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations = out["structural_violations"][ violations = out["structural_violations"][
"total_per_residue_violations_mask" "total_per_residue_violations_mask"
] ]
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations return min_pdb, debug_data, violations
...@@ -18,8 +18,14 @@ import io ...@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants from openfold.np import residue_constants
from Bio import PDB from Bio import PDB
import numpy as np import numpy as np
from simtk.openmm import app as openmm_app try:
from simtk.openmm.app.internal.pdbstructure import PdbStructure # openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment