Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
...@@ -22,6 +22,7 @@ from openfold.utils.loss import ( ...@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm, compute_tm,
compute_predicted_aligned_error, compute_predicted_aligned_error,
) )
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module): class AuxiliaryHeads(nn.Module):
...@@ -137,7 +138,7 @@ class DistogramHead(nn.Module): ...@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final") self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z] def _forward(self, z): # [*, N, N, C_z]
""" """
Args: Args:
z: z:
...@@ -150,6 +151,13 @@ class DistogramHead(nn.Module): ...@@ -150,6 +151,13 @@ class DistogramHead(nn.Module):
logits = logits + logits.transpose(-2, -3) logits = logits + logits.transpose(-2, -3)
return logits return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
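A note on the new forward wrapper above: the pattern is to detect an active fp16 autocast context and re-run the numerically sensitive head in float32 with autocast disabled. A minimal standalone sketch of the same idea, assuming only stock PyTorch (is_fp16_enabled is OpenFold's helper; torch.is_autocast_enabled is used as a stand-in here):
import torch
def fp32_guard(fn, *tensors):
    # Sketch only: if mixed-precision autocast is active, disable it and
    # upcast the inputs so fn runs entirely in float32.
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return fn(*(t.float() for t in tensors))
    return fn(*tensors)
# hypothetical usage: logits = fp32_guard(head._forward, z)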
class TMScoreHead(nn.Module): class TMScoreHead(nn.Module):
""" """
......
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import weakref
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -34,12 +35,26 @@ from openfold.model.embedders import ( ...@@ -34,12 +35,26 @@ from openfold.model.embedders import (
) )
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
) )
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add,
dict_multimap, dict_multimap,
tensor_tree_map, tensor_tree_map,
) )
...@@ -61,55 +76,96 @@ class AlphaFold(nn.Module): ...@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__() super(AlphaFold, self).__init__()
self.globals = config.globals self.globals = config.globals
config = config.model self.config = config.model
template_config = config.template self.template_config = self.config.template
extra_msa_config = config.extra_msa self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module # Main trunk + structure module
if(self.globals.is_multimer): if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer( self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"], **self.config["input_embedder"],
) )
else: else:
self.input_embedder = InputEmbedder( self.input_embedder = InputEmbedder(
**config["input_embedder"], **self.config["input_embedder"],
) )
self.recycling_embedder = RecyclingEmbedder( self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"], **self.config["recycling_embedder"],
) )
if (self.template_config.enabled):
if(self.globals.is_multimer): if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer( self.template_embedder = TemplateEmbedderMultimer(
template_config, self.template_config,
) )
else: else:
self.template_embedder = TemplateEmbedder( self.template_embedder = TemplateEmbedder(
template_config, self.template_config,
) )
if (self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder( self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"], **self.extra_msa_config["extra_msa_embedder"],
) )
self.extra_msa_stack = ExtraMSAStack( self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"], **self.extra_msa_config["extra_msa_stack"],
) )
self.evoformer = EvoformerStack( self.evoformer = EvoformerStack(
**config["evoformer_stack"], **self.config["evoformer_stack"],
) )
self.structure_module = StructureModule( self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer, is_multimer=self.globals.is_multimer,
**config["structure_module"], **self.config["structure_module"],
) )
self.aux_heads = AuxiliaryHeads( self.aux_heads = AuxiliaryHeads(
config["heads"], self.config["heads"],
)
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
) )
self.config = config template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
return template_embeds
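One detail worth noting from embed_templates above: the multimer branch builds its multichain mask with a broadcasted equality over asym_id. A tiny self-contained illustration (the asym_id values here are made up):
import torch
# asym_id assigns a chain index to each residue; the broadcasted comparison
# yields an [N, N] mask that is True only for same-chain residue pairs.
asym_id = torch.tensor([0, 0, 1, 1, 1])
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
print(multichain_mask_2d.int())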
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary # Primary output dictionary
outputs = {} outputs = {}
...@@ -126,17 +182,36 @@ class AlphaFold(nn.Module): ...@@ -126,17 +182,36 @@ class AlphaFold(nn.Module):
n_seq = feats["msa_feat"].shape[-3] n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
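# (Illustration, not part of the diff: inplace_safe becomes True only in
# eval mode with autograd globally disabled, e.g. under `torch.no_grad()`.
# Checking self.training as well keeps in-place ops disabled during the
# no-grad forward pass that activation checkpointing performs.)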
# Prep some features # Prep some features
seq_mask = feats["seq_mask"] seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :] pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"] msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations ## Initialize the MSA and pair representations
if (self.globals.is_multimer):
# m: [*, S_c, N, C_m] # m: [*, S_c, N, C_m]
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
m, z = self.input_embedder(feats) m, z = self.input_embedder(feats)
else:
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
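# (Aside, not part of the diff: popping empties `prevs`, so the list itself
# no longer keeps the recycled tensors alive. With a toy list,
#     prevs = ["m_1", "z", "x"]
#     m_1, z, x = reversed([prevs.pop() for _ in range(3)])
# leaves prevs == [] while preserving the original ordering m_1, z, x.)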
# Initialize the recycling embeddings, if needs be # Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
...@@ -161,30 +236,33 @@ class AlphaFold(nn.Module): ...@@ -161,30 +236,33 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None feats["aatype"], x_prev, None
).to(dtype=z.dtype) ).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m] # m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z] # z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder( m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, x_prev,
inplace_safe=inplace_safe,
) )
# If the number of recycling iterations is 0, skip recycling if(self.globals.offload_inference and inplace_safe):
# altogether. We zero them this way instead of computing them m = m.to(m_1_prev_emb.device)
# conditionally to avoid leaving parameters unused, which has annoying z = z.to(z_prev.device)
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m] # [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z] # [*, N, N, C_z]
z += z_prev_emb z = add(z, z_prev_emb, inplace=inplace_safe)
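# (For illustration only: a minimal version of the tensor_utils.add helper
# assumed above would be
#     def add(m1, m2, inplace):
#         if inplace:
#             m1 += m2
#             return m1
#         return m1 + m2
# so the pair update is done in place only when inplace_safe is True.)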
# Possibly prevents memory fragmentation # Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
...@@ -193,37 +271,23 @@ class AlphaFold(nn.Module): ...@@ -193,37 +271,23 @@ class AlphaFold(nn.Module):
k: v for k, v in feats.items() if k.startswith("template_") k: v for k, v in feats.items() if k.startswith("template_")
} }
if(self.globals.is_multimer): template_embeds = self.embed_templates(
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_feats, template_feats,
feats,
z, z,
pair_mask.to(dtype=z.dtype), pair_mask.to(dtype=z.dtype),
no_batch_dims, no_batch_dims,
self.globals.chunk_size inplace_safe=inplace_safe,
) )
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if( if(
self.config.template.embed_angles or "template_single_embedding" in template_embeds
(self.globals.is_multimer and self.config.template.enabled)
): ):
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
...@@ -253,15 +317,34 @@ class AlphaFold(nn.Module): ...@@ -253,15 +317,34 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e] # [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats) extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat) a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z] # [*, N, N, C_z]
z = self.extra_msa_stack( z = self.extra_msa_stack(
extra_msa_feat, a, z,
z, msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z.dtype), use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -269,12 +352,29 @@ class AlphaFold(nn.Module): ...@@ -269,12 +352,29 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m] # m: [*, S, N, C_m]
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
# s: [*, N, C_s] # s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer( m, z, s = self.evoformer(
m, m,
z, z,
msa_mask=msa_mask.to(dtype=m.dtype), msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype), pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -282,12 +382,15 @@ class AlphaFold(nn.Module): ...@@ -282,12 +382,15 @@ class AlphaFold(nn.Module):
outputs["pair"] = z outputs["pair"] = z
outputs["single"] = s outputs["single"] = s
del z
# Predict 3D structure # Predict 3D structure
outputs["sm"] = self.structure_module( outputs["sm"] = self.structure_module(
s, outputs,
z,
feats["aatype"], feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype), mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
) )
outputs["final_atom_positions"] = atom14_to_atom37( outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats outputs["sm"]["positions"][-1], feats
...@@ -301,7 +404,7 @@ class AlphaFold(nn.Module): ...@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :] m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z] # [*, N, N, C_z]
z_prev = z z_prev = outputs["pair"]
# [*, N, 3] # [*, N, 3]
x_prev = outputs["final_atom_positions"] x_prev = outputs["final_atom_positions"]
...@@ -379,10 +482,9 @@ class AlphaFold(nn.Module): ...@@ -379,10 +482,9 @@ class AlphaFold(nn.Module):
""" """
# Initialize recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop # Main recycling loop
num_iters = batch["aatype"].shape[-1] num_iters = batch["aatype"].shape[-1]
...@@ -395,7 +497,6 @@ class AlphaFold(nn.Module): ...@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter = cycle_no == (num_iters - 1) is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter: if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766) # Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
torch.clear_autocast_cache() torch.clear_autocast_cache()
...@@ -403,12 +504,15 @@ class AlphaFold(nn.Module): ...@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model # Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, feats,
m_1_prev, prevs,
z_prev,
x_prev,
_recycle=(num_iters > 1) _recycle=(num_iters > 1)
) )
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads # Run auxiliary heads
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
......
...@@ -26,8 +26,8 @@ from openfold.model.primitives import ( ...@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable, _attention_chunked_trainable,
) )
from openfold.utils.checkpointing import get_checkpoint_fn from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
) )
...@@ -89,21 +89,38 @@ class MSAAttention(nn.Module): ...@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def _chunk(self, def _chunk(self,
m: torch.Tensor, m: torch.Tensor,
biases: List[torch.Tensor], biases: Optional[List[torch.Tensor]],
use_memory_efficient_kernel: bool,
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
mha = partial( def fn(m, biases, flash_mask):
self.mha, m = self.layer_norm_m(m)
use_memory_efficient_kernel=use_memory_efficient_kernel return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
) )
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer( return chunk_layer(
mha, fn,
{ inputs,
"q_x": m,
"kv_x": m,
"biases": biases,
},
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]) no_batch_dims=len(m.shape[:-2])
) )
...@@ -111,11 +128,9 @@ class MSAAttention(nn.Module): ...@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self, def _prep_inputs(self,
m: torch.Tensor, m: torch.Tensor,
z: Optional[torch.Tensor], z: Optional[torch.Tensor],
mask: Optional[torch.Tensor] mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1] n_seq, n_res = m.shape[-3:-1]
if mask is None: if mask is None:
# [*, N_seq, N_res] # [*, N_seq, N_res]
...@@ -131,11 +146,20 @@ class MSAAttention(nn.Module): ...@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript self.linear_z is not None # TorchScript
): ):
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm_z(z) z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads] # [*, N_res, N_res, no_heads]
z = self.linear_z(z) z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res] # [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
...@@ -149,6 +173,7 @@ class MSAAttention(nn.Module): ...@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor], mask: Optional[torch.Tensor],
chunk_logits: int, chunk_logits: int,
checkpoint: bool, checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
""" """
MSA attention with training-time chunking of the softmax computation. MSA attention with training-time chunking of the softmax computation.
...@@ -158,7 +183,10 @@ class MSAAttention(nn.Module): ...@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM = -4 MSA_DIM = -4
def _get_qkv(m, z): def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask) m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m) q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z return m, q, k, v, mask_bias, z
...@@ -193,6 +221,9 @@ class MSAAttention(nn.Module): ...@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None, _chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None, _checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -214,23 +245,43 @@ class MSAAttention(nn.Module): ...@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None): if(_chunk_logits is not None):
return self._chunked_msa_attn( return self._chunked_msa_attn(
m=m, z=z, mask=mask, m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
) )
m, mask_bias, z = self._prep_inputs(m, z, mask) if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias] biases = [mask_bias]
if(z is not None): if(z is not None):
biases.append(z) biases.append(z)
if chunk_size is not None: if chunk_size is not None:
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size) m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else: else:
m = self.layer_norm_m(m)
m = self.mha( m = self.mha(
q_x=m, q_x=m,
kv_x=m, kv_x=m,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
) )
return m return m
...@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module): ...@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module): ...@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if mask is not None: if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size) m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
...@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
mha_input = { mha_input = {
"m": m, "m": m,
"mask": mask, "mask": mask,
} }
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer( return chunk_layer(
self.global_attention, fn,
mha_input, mha_input,
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]), no_batch_dims=len(m.shape[:-2]),
...@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:] n_seq, n_res, c_in = m.shape[-3:]
...@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m) #m = self.layer_norm_m(m)
if chunk_size is not None: if chunk_size is not None:
m = self._chunk(m, mask, chunk_size) m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else: else:
m = self.global_attention(m=m, mask=mask) m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
......
...@@ -20,7 +20,8 @@ import torch ...@@ -20,7 +20,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module): class OuterProductMean(nn.Module):
...@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module): ...@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims=1, no_batch_dims=1,
) )
out.append(outer) out.append(outer)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0) outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer return outer
def forward(self, def _forward(self,
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module): ...@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask = m.new_ones(m.shape[:-1]) mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m] # [*, N_seq, N_res, C_m]
m = self.layer_norm(m) ln = self.layer_norm(m)
# [*, N_seq, N_res, C] # [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask a = self.linear_1(ln)
b = self.linear_2(m) * mask a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3) a = a.transpose(-2, -3)
b = b.transpose(-2, -3) b = b.transpose(-2, -3)
...@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module): ...@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1] # [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask) norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm) if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
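A toy illustration of the masked mean computed by this module may help: the outer products are summed over sequences and divided by the per-pair count of valid sequences. This sketch uses made-up shapes and plain einsum, not the chunked OpenFold implementation:
import torch
s, n, c = 4, 3, 2                               # sequences, residues, channels
a = torch.randn(s, n, c)
b = torch.randn(s, n, c)
mask = torch.ones(s, n, 1)
# Sum of outer products over the sequence dimension, flattened to [n, n, c*c]
outer = torch.einsum("sic,sjd->ijcd", a * mask, b * mask).reshape(n, n, -1)
# Number of valid sequences contributing to each residue pair
norm = torch.einsum("sic,sjc->ijc", mask, mask)
outer = outer / (norm + 1e-3)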
...@@ -18,7 +18,7 @@ import torch ...@@ -18,7 +18,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module): class PairTransition(nn.Module):
...@@ -46,12 +46,16 @@ class PairTransition(nn.Module): ...@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask): def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden] # [*, N_res, N_res, C_hidden]
z = self.linear_1(z) z = self.linear_1(z)
z = self.relu(z) z = self.relu(z)
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask z = self.linear_2(z)
z = z * mask
return z return z
...@@ -68,7 +72,6 @@ class PairTransition(nn.Module): ...@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]), no_batch_dims=len(z.shape[:-2]),
) )
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
...@@ -88,9 +91,6 @@ class PairTransition(nn.Module): ...@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1] # [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None: if chunk_size is not None:
z = self._chunk(z, mask, chunk_size) z = self._chunk(z, mask, chunk_size)
else: else:
......
...@@ -13,24 +13,39 @@ ...@@ -13,24 +13,39 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import importlib
import math import math
from typing import Optional, Callable, List, Tuple, Sequence from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np import numpy as np
import deepspeed deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch import torch
import torch.nn as nn import torch.nn as nn
from scipy.stats import truncnorm from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
_chunk_slice,
) )
DEFAULT_LMA_Q_CHUNK_SIZE=1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096
def _prod(nums): def _prod(nums):
out = 1 out = 1
for n in nums: for n in nums:
...@@ -145,6 +160,7 @@ class Linear(nn.Linear): ...@@ -145,6 +160,7 @@ class Linear(nn.Linear):
with torch.no_grad(): with torch.no_grad():
self.bias.fill_(0) self.bias.fill_(0)
with torch.no_grad():
if init_fn is not None: if init_fn is not None:
init_fn(self.weight, self.bias) init_fn(self.weight, self.bias)
else: else:
...@@ -157,7 +173,6 @@ class Linear(nn.Linear): ...@@ -157,7 +173,6 @@ class Linear(nn.Linear):
elif init == "gating": elif init == "gating":
gating_init_(self.weight) gating_init_(self.weight)
if bias: if bias:
with torch.no_grad():
self.bias.fill_(1.0) self.bias.fill_(1.0)
elif init == "normal": elif init == "normal":
normal_init_(self.weight) normal_init_(self.weight)
...@@ -179,7 +194,11 @@ class LayerNorm(nn.Module): ...@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
def forward(self, x): def forward(self, x):
d = x.dtype d = x.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm( out = nn.functional.layer_norm(
x, x,
...@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: ...@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
type bfloat16 type bfloat16
""" """
d = t.dtype d = t.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim) s = torch.nn.functional.softmax(t, dim=dim)
else: else:
...@@ -403,8 +426,10 @@ class Attention(nn.Module): ...@@ -403,8 +426,10 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None, biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_lma: bool = False, use_lma: bool = False,
q_chunk_size: Optional[int] = None, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
kv_chunk_size: Optional[int] = None, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -423,29 +448,41 @@ class Attention(nn.Module): ...@@ -423,29 +448,41 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead implementation is used instead
q_chunk_size: lma_q_chunk_size:
Query chunk size (for LMA) Query chunk size (for LMA)
kv_chunk_size: lma_kv_chunk_size:
Key/Value chunk size (for LMA) Key/Value chunk size (for LMA)
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
if(biases is None): if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError( raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must " "If use_lma is specified, lma_q_chunk_size and "
"be provided" "lma_kv_chunk_size must be provided"
) )
if(use_memory_efficient_kernel and use_lma):
if(use_flash and biases is not None):
raise ValueError(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
raise ValueError( raise ValueError(
"Choose one of use_memory_efficient_kernel and use_lma" "Choose at most one alternative attention algorithm"
) )
if(biases is None):
biases = []
# [*, H, Q/K, C_hidden] # [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x) q, k, v = self._prep_qkv(q_x, kv_x)
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
if is_fp16_enabled():
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel): if(use_memory_efficient_kernel):
if(len(biases) > 2): if(len(biases) > 2):
raise ValueError( raise ValueError(
...@@ -459,7 +496,10 @@ class Attention(nn.Module): ...@@ -459,7 +496,10 @@ class Attention(nn.Module):
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases for b in biases
] ]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif(use_flash):
o = _flash_attn(q, k, v, flash_mask)
else: else:
o = _attention(q, k, v, biases) o = _attention(q, k, v, biases)
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
...@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module): ...@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: def forward(self,
m: torch.Tensor,
mask: torch.Tensor,
use_lma: bool = False,
) -> torch.Tensor:
# [*, N_res, C_in] # [*, N_res, C_in]
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / ( q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps torch.sum(mask, dim=-1)[..., None] + self.eps
...@@ -511,12 +555,13 @@ class GlobalAttention(nn.Module): ...@@ -511,12 +555,13 @@ class GlobalAttention(nn.Module):
k = self.linear_k(m) k = self.linear_k(m)
v = self.linear_v(m) v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma):
# [*, N_res, H, N_seq] # [*, N_res, H, N_seq]
a = torch.matmul( a = torch.matmul(
q, q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
) )
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias a += bias
a = softmax_no_cast(a) a = softmax_no_cast(a)
...@@ -525,6 +570,15 @@ class GlobalAttention(nn.Module): ...@@ -525,6 +570,15 @@ class GlobalAttention(nn.Module):
a, a,
v, v,
) )
else:
o = _lma(
q,
k,
v,
[bias],
DEFAULT_LMA_Q_CHUNK_SIZE,
DEFAULT_LMA_KV_CHUNK_SIZE
)
# [*, N_res, N_seq, C_hidden] # [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m)) g = self.sigmoid(self.linear_g(m))
...@@ -552,12 +606,12 @@ def _lma( ...@@ -552,12 +606,12 @@ def _lma(
q_chunk_size: int, q_chunk_size: int,
kv_chunk_size: int, kv_chunk_size: int,
): ):
no_q, no_kv = q.shape[-3], k.shape[-3] no_q, no_kv = q.shape[-2], k.shape[-2]
# [*, Q, H, C_hidden] # [*, H, Q, C_hidden]
o = q.new_zeros(q.shape) o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size): for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :] q_chunk = q[..., q_s: q_s + q_chunk_size, :]
large_bias_chunks = [ large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases b[..., q_s: q_s + q_chunk_size, :] for b in biases
] ]
...@@ -566,24 +620,22 @@ def _lma( ...@@ -566,24 +620,22 @@ def _lma(
weights = [] weights = []
values = [] values = []
for kv_s in range(0, no_kv, kv_chunk_size): for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :] k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :] v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
small_bias_chunks = [ small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
] ]
a = torch.einsum( a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk, "...hqd,...hkd->...hqk", q_chunk, k_chunk,
) )
for b in small_bias_chunks: for b in small_bias_chunks:
a += b a += b
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0] max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a) exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1)) maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1)) weights.append(torch.sum(exp_a, dim=-1))
...@@ -595,14 +647,80 @@ def _lma( ...@@ -595,14 +647,80 @@ def _lma(
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max) max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1) chunk_values = chunk_values * max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs chunk_weights = chunk_weights * max_diffs
all_values = torch.sum(chunk_values, dim=-4) all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights q_chunk_out = all_values / all_weights
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
return o return o
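The chunked attention above relies on the streaming-softmax trick: keep per-chunk exponentials together with their maxima, then rescale when combining. A minimal, self-contained check of that idea against dense attention (toy shapes, no heads, not OpenFold code):
import torch
torch.manual_seed(0)
q, k, v = torch.randn(8, 16), torch.randn(32, 16), torch.randn(32, 16)
ref = torch.softmax(q @ k.T, dim=-1) @ v        # dense reference
chunk = 8
max_a = torch.full((q.shape[0], 1), -float("inf"))
num = torch.zeros_like(ref)
den = torch.zeros(q.shape[0], 1)
for s in range(0, k.shape[0], chunk):
    a = q @ k[s:s + chunk].T                    # [Q, chunk] logits
    new_max = torch.maximum(max_a, a.max(dim=-1, keepdim=True).values)
    scale = torch.exp(max_a - new_max)          # rescale running accumulators
    exp_a = torch.exp(a - new_max)
    num = num * scale + exp_a @ v[s:s + chunk]
    den = den * scale + exp_a.sum(dim=-1, keepdim=True)
    max_a = new_max
assert torch.allclose(num / den, ref, atol=1e-5)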
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed):
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype
q = q.half()
k = k.half()
v = v.half()
kv_mask = kv_mask.half()
# [*, B, N, H, C]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# [B_flat, N, H, C]
q = q.reshape(-1, *q.shape[-3:])
k = k.reshape(-1, *k.shape[-3:])
v = v.reshape(-1, *v.shape[-3:])
# Flattened batch size
batch_size = q.shape[0]
# [B_flat * N, H, C]
q = q.reshape(-1, *q.shape[-2:])
q_max_s = n
q_cu_seqlens = torch.arange(
0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
)
# [B_flat, N, 2, H, C]
kv = torch.stack([k, v], dim=-3)
kv_shape = kv.shape
# [B_flat, N, 2 * H * C]
kv = kv.reshape(*kv.shape[:-3], -1)
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p = 0.,
softmax_scale = 1., # q has been scaled already
)
# [*, B, N, H, C]
out = out.reshape(*batch_dims, n, no_heads, c)
out = out.to(dtype=dtype)
return out
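One shape detail in _flash_attn above that is easy to miss: the query side is treated as dense, so its cumulative sequence lengths are simple multiples of n, while the key/value side is unpadded according to kv_mask. A tiny illustration of the query-side bookkeeping (made-up sizes):
import torch
batch_size, n = 3, 5
q_cu_seqlens = torch.arange(
    0, (batch_size + 1) * n, step=n, dtype=torch.int32
)
print(q_cu_seqlens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)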
...@@ -21,8 +21,8 @@ import torch ...@@ -21,8 +21,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
) )
...@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import ( ...@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
class TriangleAttention(nn.Module): class TriangleAttention(nn.Module):
def __init__( def __init__(
self, c_in, c_hidden, no_heads, starting, inf=1e9 self, c_in, c_hidden, no_heads, starting=True, inf=1e9
): ):
""" """
Args: Args:
...@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module): ...@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
biases: List[torch.Tensor], biases: List[torch.Tensor],
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = { mha_inputs = {
"q_x": x, "q_x": x,
"kv_x": x, "kv_x": x,
"biases": biases, "biases": biases,
} }
return chunk_layer( return chunk_layer(
partial(self.mha), partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
),
mha_inputs, mha_inputs,
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]), no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
) )
def forward(self, def forward(self,
x: torch.Tensor, x: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -93,8 +106,7 @@ class TriangleAttention(nn.Module): ...@@ -93,8 +106,7 @@ class TriangleAttention(nn.Module):
x.shape[:-1], x.shape[:-1],
) )
# Shape annotations assume self.starting. Else, I and J are flipped if(not self.starting):
if not self.starting:
x = x.transpose(-2, -3) x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
...@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module): ...@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
biases = [mask_bias, triangle_bias] biases = [mask_bias, triangle_bias]
if chunk_size is not None: if chunk_size is not None:
x = self._chunk(x, biases, chunk_size) x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else: else:
x = self.mha(q_x=x, kv_x=x, biases=biases) x = self.mha(
q_x=x,
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
)
if not self.starting: if(not self.starting):
x = x.transpose(-2, -3) x = x.transpose(-2, -3)
return x return x
class TriangleAttentionStartingNode(TriangleAttention): # Implements Algorithm 13
""" TriangleAttentionStartingNode = TriangleAttention
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
class TriangleAttentionEndingNode(TriangleAttention): class TriangleAttentionEndingNode(TriangleAttention):
""" """
Implements Algorithm 14. Implements Algorithm 14.
""" """
__init__ = partialmethod(TriangleAttention.__init__, starting=False) __init__ = partialmethod(TriangleAttention.__init__, starting=False)
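The StartingNode/EndingNode simplification above leans on functools.partialmethod to bake a constructor argument into a subclass. A toy version of the same pattern (class names here are invented):
from functools import partialmethod
class Attn:
    def __init__(self, starting=True):
        self.starting = starting
class AttnEndingNode(Attn):
    # Same trick as TriangleAttentionEndingNode: pin starting=False
    __init__ = partialmethod(Attn.__init__, starting=False)
print(AttnEndingNode().starting)  # False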
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
...@@ -28,10 +28,18 @@ import openfold.utils.loss as loss ...@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils from openfold.np.relax import cleanup, utils
import ml_collections import ml_collections
import numpy as np import numpy as np
from simtk import openmm try:
from simtk import unit # openmm >= 7.6
from simtk.openmm import app as openmm_app import openmm
from simtk.openmm.app.internal.pdbstructure import PdbStructure from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
ENERGY = unit.kilocalories_per_mole ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms LENGTH = unit.angstroms
...@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True): ...@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions()) pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks: if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string) _check_cleaned_atoms(pdb_string, prot_pdb_string)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
pdb_string = '\n'.join(['\n'.join(headers), pdb_string])
return pdb_string return pdb_string
...@@ -511,6 +524,9 @@ def run_pipeline( ...@@ -511,6 +524,9 @@ def run_pipeline(
_check_residues_are_well_defined(prot) _check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks) pdb_string = clean_protein(prot, checks=checks)
# We keep the input around to restore metadata deleted by the relaxer
input_prot = prot
exclude_residues = exclude_residues or [] exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues) exclude_residues = set(exclude_residues)
violations = np.inf violations = np.inf
...@@ -527,6 +543,11 @@ def run_pipeline( ...@@ -527,6 +543,11 @@ def run_pipeline(
max_attempts=max_attempts, max_attempts=max_attempts,
use_gpu=use_gpu, use_gpu=use_gpu,
) )
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])
prot = protein.from_pdb_string(ret["min_pdb"]) prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration: if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True) pdb_string = clean_protein(prot, checks=True)
......
...@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure). ...@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io import io
import pdbfixer import pdbfixer
from simtk.openmm import app try:
from simtk.openmm.app import element # openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info): def fix_pdb(pdbfile, alterations_info):
......
...@@ -87,4 +87,7 @@ class AmberRelaxation(object): ...@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations = out["structural_violations"][ violations = out["structural_violations"][
"total_per_residue_violations_mask" "total_per_residue_violations_mask"
] ]
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations return min_pdb, debug_data, violations
...@@ -18,8 +18,14 @@ import io ...@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants from openfold.np import residue_constants
from Bio import PDB from Bio import PDB
import numpy as np import numpy as np
from simtk.openmm import app as openmm_app try:
from simtk.openmm.app.internal.pdbstructure import PdbStructure # openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
......
...@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation): ...@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions # and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the # and compute affine transformation matrices (4,4) from one rigid group to the
# previous group # previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int) restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int) restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
...@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds( ...@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile( restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int), (21, 1) np.arange(14, dtype=int), (21, 1)
) )
......
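The dtype changes in this last hunk track NumPy's removal of the np.int alias (deprecated in 1.20, removed in 1.24); the builtin int is the drop-in replacement. A quick check that the replacement behaves as before:
import numpy as np
swap_idx = np.tile(np.arange(14, dtype=int), (21, 1))
print(swap_idx.shape, swap_idx.dtype)  # (21, 14) int64 (platform-dependent)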