Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm,
compute_predicted_aligned_error,
)
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z]
def _forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
@@ -150,6 +151,13 @@ class DistogramHead(nn.Module):
logits = logits + logits.transpose(-2, -3)
return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
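The pattern above — a private `_forward` plus a public wrapper that escapes fp16 autocast — recurs throughout this commit. A minimal self-contained sketch of the idea, using `torch.is_autocast_enabled()` as an illustrative stand-in for `is_fp16_enabled()` (whose implementation is not shown in this diff):

```python
# Sketch of the fp16-escape pattern. The real code consults
# openfold.utils.precision_utils.is_fp16_enabled; torch.is_autocast_enabled()
# is used here only as a stand-in.
import torch
import torch.nn as nn

class Fp16SafeHead(nn.Module):
    def __init__(self, c_z: int, no_bins: int):
        super().__init__()
        self.linear = nn.Linear(c_z, no_bins)

    def _forward(self, z: torch.Tensor) -> torch.Tensor:
        logits = self.linear(z)
        return logits + logits.transpose(-2, -3)  # symmetrize over residues

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        if torch.is_autocast_enabled():
            # Upcast and disable autocast so the head runs in full precision,
            # avoiding fp16 overflow in the symmetrized logits.
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(z.float())
        return self._forward(z)
```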
class TMScoreHead(nn.Module):
"""
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import weakref
import torch
import torch.nn as nn
@@ -34,12 +35,26 @@ from openfold.model.embedders import (
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import (
compute_plddt,
)
from openfold.utils.tensor_utils import (
add,
dict_multimap,
tensor_tree_map,
)
@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module
if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
**self.config["input_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
**self.config["recycling_embedder"],
)
if (self.template_config.enabled):
if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer(
template_config,
self.template_config,
)
else:
self.template_embedder = TemplateEmbedder(
template_config,
self.template_config,
)
if (self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
**self.extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
**self.extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
**self.config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**config["structure_module"],
**self.config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
self.config["heads"],
)
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
self.config = config
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
return template_embeds
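The `multichain_mask_2d` used in the multimer branch is just a same-chain indicator built by broadcasting `asym_id` against itself. A tiny worked example:

```python
# Toy example of the multichain mask: entry (i, j) is True only when
# residues i and j belong to the same chain.
import torch

asym_id = torch.tensor([0, 0, 1, 1, 1])  # chain id per residue
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
print(multichain_mask_2d.int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```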
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
outputs = {}
@@ -126,17 +182,36 @@ class AlphaFold(nn.Module):
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
## Initialize the MSA and pair representations
if (self.globals.is_multimer):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
else:
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
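Popping the entries out of `prevs` (rather than indexing into it) matters here: once the list is empty it holds no references, so the later `del` statements can actually release the tensors. The idiom in isolation:

```python
# The unpack-by-pop idiom: after this, `prevs` keeps nothing alive.
# reversed() compensates for pop() taking items from the end of the list.
import torch

prevs = [torch.zeros(4), torch.ones(4), torch.full((4,), 2.0)]
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
assert prevs == []          # the list no longer pins the tensors
assert m_1_prev[0] == 0 and x_prev[0] == 2
```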
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
@@ -161,30 +236,33 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
inplace_safe=inplace_safe,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
z = add(z, z_prev_emb, inplace=inplace_safe)
# Possibly prevents memory fragmentation
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
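`add(z, z_prev_emb, inplace=inplace_safe)` replaces the unconditional `z += z_prev_emb`. A plausible sketch of such a helper — hedged, since the real one lives in `openfold.utils.tensor_utils` and its exact behavior may differ:

```python
# Hedged sketch of an add-with-optional-inplace helper. In-place addition is
# only legal when no autograd graph needs the original tensor, which is what
# the inplace_safe flag computed earlier guarantees.
import torch

def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    if inplace:
        m1 += m2      # reuses m1's storage; saves a full-sized allocation
    else:
        m1 = m1 + m2  # out-of-place; safe under autograd
    return m1
```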
# Embed the templates + merge with MSA/pair embeddings
@@ -193,37 +271,23 @@ class AlphaFold(nn.Module):
k: v for k, v in feats.items() if k.startswith("template_")
}
if(self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_embeds = self.embed_templates(
template_feats,
feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
@@ -253,15 +317,34 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z.dtype),
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
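The offload branches rely on a small handshake: the caller packs its only tensor references into a list and deletes the local names, so the callee may move data between devices without stale references pinning GPU memory. A stripped-down illustration (the function name and body here are ours, not the library's):

```python
# Stripped-down illustration of the offload handshake. _forward_offload
# stands in for the real stack methods; it may swap entries of the list to
# CPU because the caller retains no other references to them.
import torch

def _forward_offload(input_tensors):
    a, z = input_tensors
    z = z + a.sum()             # placeholder for the real computation
    input_tensors[0] = a.cpu()  # offload; the device copy can now be freed
    del a
    return z

a, z = torch.randn(8, 4), torch.randn(4, 4)
input_tensors = [a, z]
del a, z                        # caller gives up its direct references
z = _forward_offload(input_tensors)
del input_tensors
```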
@@ -269,12 +352,29 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
@@ -282,12 +382,15 @@ class AlphaFold(nn.Module):
outputs["pair"] = z
outputs["single"] = s
del z
# Predict 3D structure
outputs["sm"] = self.structure_module(
s,
z,
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
z_prev = outputs["pair"]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
@@ -379,10 +482,9 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
use_memory_efficient_kernel: bool,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor:
mha = partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel
def fn(m, biases, flash_mask):
m = self.layer_norm_m(m)
return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
)
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer(
mha,
{
"q_x": m,
"kv_x": m,
"biases": biases,
},
fn,
inputs,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2])
)
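`_chunk` now builds a closure that layer-norms each slice before attention and strips constant arguments via `partial`, then hands everything to `chunk_layer`. For intuition, here is a simplified, hypothetical one-dimensional version of what a `chunk_layer`-style helper does (the real `openfold.utils.chunk_utils.chunk_layer` handles arbitrary batch dimensions and nested inputs):

```python
# Simplified, hypothetical chunk_layer: slice every input along the leading
# dimension, apply fn to each slice, and concatenate the results. Peak
# activation memory then scales with chunk_size rather than the full size.
import torch

def chunk_layer_1d(fn, inputs: dict, chunk_size: int) -> torch.Tensor:
    n = next(iter(inputs.values())).shape[0]
    out = []
    for i in range(0, n, chunk_size):
        sliced = {k: v[i: i + chunk_size] for k, v in inputs.items()}
        out.append(fn(**sliced))
    return torch.cat(out, dim=0)

m = torch.randn(128, 16)
doubled = chunk_layer_1d(lambda m: m * 2, {"m": m}, chunk_size=32)
assert torch.allclose(doubled, m * 2)
```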
@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
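Chunking the pair tensor into 256-row slices bounds the transient memory of the layer norm and projection without changing the result:

```python
# Equivalence check for the 256-row chunking above: slicing along the row
# dimension changes peak memory, not the output.
import torch

z = torch.randn(2, 300, 300, 8)
layer_norm_z = torch.nn.LayerNorm(8)
linear_z = torch.nn.Linear(8, 4)

chunks = []
for i in range(0, z.shape[-3], 256):
    z_chunk = z[..., i: i + 256, :, :]
    chunks.append(linear_z(layer_norm_z(z_chunk)))
z_out = torch.cat(chunks, dim=-3)

assert torch.allclose(z_out, linear_z(layer_norm_z(z)), atol=1e-6)
```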
@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else:
m = self.layer_norm_m(m)
m = self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
return m
@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if mask is not None:
mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer(
self.global_attention,
fn,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
#m = self.layer_norm_m(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
m = self.global_attention(m=m, mask=mask)
m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
@@ -20,7 +20,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module):
@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims=1,
)
out.append(outer)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def forward(self,
def _forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
ln = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
b = self.linear_2(m) * mask
a = self.linear_1(ln)
a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm)
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer
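The einsum on the mask computes, for every residue pair, how many sequences contain both positions — exactly the denominator the outer-product mean needs. A worked example:

```python
# Worked example of the normalization term: norm[b, d] counts the sequences
# in which residues b and d are both unmasked.
import torch

# [S=2, N=3, 1]: sequence 0 covers all residues, sequence 1 misses residue 2
mask = torch.tensor([[1., 1., 1.], [1., 1., 0.]]).unsqueeze(-1)
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
print(norm.squeeze(-1))
# tensor([[2., 2., 1.],
#         [2., 2., 1.],
#         [1., 1., 1.]])
```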
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module):
@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
z = self.linear_2(z)
z = z * mask
return z
@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
@@ -13,24 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)
DEFAULT_LMA_Q_CHUNK_SIZE=1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096
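Both optional backends above are probed with `importlib.util.find_spec`, which checks availability without importing anything; the resulting flag later guards calls such as `deepspeed.utils.is_initialized()`. The idiom in isolation:

```python
# Probe-then-import idiom for optional dependencies: find_spec only looks
# the package up, so nothing heavy is loaded unless it is actually present.
import importlib.util

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if deepspeed_is_installed:
    import deepspeed
```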
def _prod(nums):
out = 1
for n in nums:
@@ -145,6 +160,7 @@ class Linear(nn.Linear):
with torch.no_grad():
self.bias.fill_(0)
with torch.no_grad():
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
@@ -157,7 +173,6 @@ class Linear(nn.Linear):
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
def forward(self, x):
d = x.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
type bfloat16
"""
d = t.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
@@ -403,8 +426,10 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
@@ -423,29 +448,41 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
q_chunk_size:
lma_q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
lma_kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
"If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided"
)
if(use_memory_efficient_kernel and use_lma):
if(use_flash and biases is not None):
raise ValueError(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
raise ValueError(
"Choose one of use_memory_efficient_kernel and use_lma"
"Choose at most one alternative attention algorithm"
)
if(biases is None):
biases = []
# [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x)
# [*, Q, H, C_hidden]
if is_fp16_enabled():
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
raise ValueError(
@@ -459,7 +496,10 @@ class Attention(nn.Module):
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif(use_flash):
o = _flash_attn(q, k, v, flash_mask)
else:
o = _attention(q, k, v, biases)
o = o.transpose(-2, -3)
@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
def forward(self,
m: torch.Tensor,
mask: torch.Tensor,
use_lma: bool = False,
) -> torch.Tensor:
# [*, N_res, C_in]
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps
@@ -511,12 +555,13 @@ class GlobalAttention(nn.Module):
k = self.linear_k(m)
v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma):
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = softmax_no_cast(a)
@@ -525,6 +570,15 @@ class GlobalAttention(nn.Module):
a,
v,
)
else:
o = _lma(
q,
k,
v,
[bias],
DEFAULT_LMA_Q_CHUNK_SIZE,
DEFAULT_LMA_KV_CHUNK_SIZE
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
@@ -552,12 +606,12 @@ def _lma(
q_chunk_size: int,
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-3], k.shape[-3]
no_q, no_kv = q.shape[-2], k.shape[-2]
# [*, Q, H, C_hidden]
# [*, H, Q, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
q_chunk = q[..., q_s: q_s + q_chunk_size, :]
large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
@@ -566,24 +620,22 @@ def _lma(
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk,
"...hqd,...hkd->...hqk", q_chunk, k_chunk,
)
for b in small_bias_chunks:
a += b
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
@@ -595,14 +647,80 @@ def _lma(
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs
chunk_values = chunk_values * max_diffs.unsqueeze(-1)
chunk_weights = chunk_weights * max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
return o
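`_lma` is an online ("streaming") softmax: each key/value chunk contributes a running max, an exponential weight sum, and an exp-weighted value sum, and chunks are merged by rescaling with the global max. A tiny numerical check of that merging rule against a stock softmax — a sketch of the math, not the OpenFold code itself:

```python
# Streaming-softmax sanity check: merging per-chunk (max, sum-exp, weighted
# values) with a global-max rescale reproduces softmax(x) @ v exactly.
import torch

x = torch.randn(10)        # attention logits for one query
v = torch.randn(10, 4)     # values
ref = torch.softmax(x, dim=-1) @ v

chunk_size = 4
maxes, weights, values = [], [], []
for s in range(0, x.shape[0], chunk_size):
    xc, vc = x[s: s + chunk_size], v[s: s + chunk_size]
    m = xc.max()
    e = torch.exp(xc - m)              # numerically stable within the chunk
    maxes.append(m)
    weights.append(e.sum())
    values.append(e @ vc)

global_max = torch.stack(maxes).max()
scale = torch.exp(torch.stack(maxes) - global_max)
num = (torch.stack(values) * scale[:, None]).sum(dim=0)
den = (torch.stack(weights) * scale).sum()
assert torch.allclose(num / den, ref, atol=1e-6)
```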
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed):
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype
q = q.half()
k = k.half()
v = v.half()
kv_mask = kv_mask.half()
# [*, B, N, H, C]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# [B_flat, N, H, C]
q = q.reshape(-1, *q.shape[-3:])
k = k.reshape(-1, *k.shape[-3:])
v = v.reshape(-1, *v.shape[-3:])
# Flattened batch size
batch_size = q.shape[0]
# [B_flat * N, H, C]
q = q.reshape(-1, *q.shape[-2:])
q_max_s = n
q_cu_seqlens = torch.arange(
0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
)
# [B_flat, N, 2, H, C]
kv = torch.stack([k, v], dim=-3)
kv_shape = kv.shape
# [B_flat, N, 2 * H * C]
kv = kv.reshape(*kv.shape[:-3], -1)
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p = 0.,
softmax_scale = 1., # q has been scaled already
)
# [*, B, N, H, C]
out = out.reshape(*batch_dims, n, no_heads, c)
out = out.to(dtype=dtype)
return out
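FlashAttention's unpadded ("varlen") interface works on tokens flattened across the batch, so sequence boundaries travel in a cumulative-sequence-length tensor. For B equal-length query sequences of length N, that tensor is just an arithmetic ramp, as built above:

```python
# cu_seqlens bookkeeping for the varlen kernel: entry i is the offset of
# sequence i in the flattened token dimension; the last entry is the total.
import torch

batch_size, n = 3, 5
q_cu_seqlens = torch.arange(
    0, (batch_size + 1) * n, step=n, dtype=torch.int32
)
print(q_cu_seqlens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)
```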
@@ -21,8 +21,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
class TriangleAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, starting, inf=1e9
self, c_in, c_hidden, no_heads, starting=True, inf=1e9
):
"""
Args:
@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
partial(self.mha),
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -93,8 +106,7 @@ class TriangleAttention(nn.Module):
x.shape[:-1],
)
# Shape annotations assume self.starting. Else, I and J are flipped
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
biases = [mask_bias, triangle_bias]
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size)
x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
x = self.mha(q_x=x, kv_x=x, biases=biases)
x = self.mha(
q_x=x,
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
)
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
return x
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
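Replacing the plain alias with `partialmethod` subclasses keeps `TriangleAttentionStartingNode` and `TriangleAttentionEndingNode` as real, distinct classes (so `isinstance` checks and code keyed on the class names keep working) while pinning the `starting` flag. The mechanism in miniature:

```python
# partialmethod pins a constructor argument while leaving a genuine subclass.
from functools import partialmethod

class Tri:
    def __init__(self, c_in, starting=True):
        self.c_in, self.starting = c_in, starting

class TriStarting(Tri):
    __init__ = partialmethod(Tri.__init__, starting=True)

class TriEnding(Tri):
    __init__ = partialmethod(Tri.__init__, starting=False)

assert TriStarting(8).starting and not TriEnding(8).starting
```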
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
try:
# openmm >= 7.6
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
pdb_string = '\n'.join(['\n'.join(headers), pdb_string])
return pdb_string
@@ -511,6 +524,9 @@ def run_pipeline(
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
# We keep the input around to restore metadata deleted by the relaxer
input_prot = prot
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
@@ -527,6 +543,11 @@ def run_pipeline(
max_attempts=max_attempts,
use_gpu=use_gpu,
)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
try:
# openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations
@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
try:
# openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int), (21, 1)
np.arange(14, dtype=int), (21, 1)
)
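`np.int` was only an alias for the built-in `int` (deprecated in NumPy 1.20, removed in 1.24), so the builtin is a drop-in replacement:

```python
# dtype=int replaces the removed np.int alias; both resolve to the platform
# default integer type (typically int64 on Linux).
import numpy as np

a = np.zeros([21, 37], dtype=int)
assert a.dtype == np.dtype(int)
```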