Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
This diff is collapsed.
......@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm,
compute_predicted_aligned_error,
)
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
......@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z]
def _forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
......@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
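The is_fp16_enabled helper comes from the new openfold.utils.precision_utils module, which is not shown in this diff. A minimal sketch of what such a check might look like on top of PyTorch's autocast API (an assumption about the helper, not its actual implementation):

import torch

def is_fp16_enabled() -> bool:
    # True when autocast is active with a float16 GPU dtype, in which case
    # callers such as DistogramHead.forward drop to fp32 for numerically
    # sensitive computations before returning to the autocast region.
    return (
        torch.is_autocast_enabled()
        and torch.get_autocast_gpu_dtype() == torch.float16
    )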
class TMScoreHead(nn.Module):
......
......@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import weakref
import torch
import torch.nn as nn
......@@ -34,12 +35,26 @@ from openfold.model.embedders import (
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import (
compute_plddt,
)
from openfold.utils.tensor_utils import (
add,
dict_multimap,
tensor_tree_map,
)
......@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module
if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
**self.config["input_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
**self.config["recycling_embedder"],
)
if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer(
template_config,
if (self.template_config.enabled):
if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
else:
self.template_embedder = TemplateEmbedder(
self.template_config,
)
if (self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
else:
self.template_embedder = TemplateEmbedder(
template_config,
self.extra_msa_stack = ExtraMSAStack(
**self.extra_msa_config["extra_msa_stack"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
**self.config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**config["structure_module"],
**self.config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
self.config["heads"],
)
self.config = config
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
return template_embeds
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
outputs = {}
......@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
## Initialize the MSA and pair representations
# Initialize the MSA and pair representations
if (self.globals.is_multimer):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
else:
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Initialize the recycling embeddings, if needs be
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
......@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
inplace_safe=inplace_safe,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
z = add(z, z_prev_emb, inplace=inplace_safe)
# Possibly prevents memory fragmentation
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
if(self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size
)
template_embeds = self.embed_templates(
template_feats,
feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
......@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z]
z = self.extra_msa_stack(
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
del z
# Predict 3D structure
outputs["sm"] = self.structure_module(
s,
z,
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
z_prev = outputs["pair"]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
......@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
......@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
......@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
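Passing the recycling tensors through a single prevs list, as above, lets iteration() pop them and drop the last references before new activations are allocated. A self-contained illustration of the same pattern (tensor shapes arbitrary):

import torch

def consume(prevs):
    # Pop all three entries so the caller's list no longer holds references,
    # then restore their original order; once the local names are deleted,
    # the tensors become eligible for garbage collection mid-call.
    m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
    # ... use m_1_prev, z_prev, x_prev ...
    del m_1_prev, z_prev, x_prev

prevs = [torch.randn(8, 16), torch.randn(8, 8, 4), torch.randn(8, 37, 3)]
consume(prevs)
assert prevs == []  # the list no longer pins the tensors in memory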
......
......@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
......@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
use_memory_efficient_kernel: bool,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor:
mha = partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel
)
def fn(m, biases, flash_mask):
m = self.layer_norm_m(m)
return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
)
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer(
mha,
{
"q_x": m,
"kv_x": m,
"biases": biases,
},
fn,
inputs,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2])
)
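chunk_layer, now imported from openfold.utils.chunk_utils rather than tensor_utils, slices the batch dimensions of the tensors in inputs into pieces of chunk_size, applies the wrapped function to each piece, and concatenates the results. A simplified single-batch-dimension sketch of that behaviour (illustrative, not the actual implementation):

import torch
from typing import Callable, Dict

def chunk_layer_1d(fn: Callable, inputs: Dict[str, torch.Tensor],
                   chunk_size: int) -> torch.Tensor:
    # Assumes every tensor in `inputs` shares the same leading batch dim.
    n = next(iter(inputs.values())).shape[0]
    out = []
    for i in range(0, n, chunk_size):
        piece = {k: v[i:i + chunk_size] for k, v in inputs.items()}
        out.append(fn(**piece))
    return torch.cat(out, dim=0)

Arguments that should not be sliced, such as biases=None or flash_mask=None above, are bound with functools.partial so that every chunk sees the same value.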
......@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
......@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
......@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
......@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
......@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
......@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z)
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else:
m = self.layer_norm_m(m)
m = self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
return m
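The flash path asserts z is None, presumably because the flash kernel cannot take the additive pair bias; the boolean mask is routed through flash_mask instead. Hypothetical call shapes for the extended signature, assuming an already-constructed attention module msa_att and inputs with the shapes documented above:

# Chunked low-memory-attention path, pair bias included:
m = msa_att(m, z=z, mask=msa_mask, chunk_size=4, use_lma=True)

# FlashAttention path: the pair bias must be disabled, and the mask is
# forwarded internally as flash_mask.
m = msa_att(m, z=None, mask=msa_mask, use_flash=True)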
......@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if mask is not None:
mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer(
self.global_attention,
fn,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
......@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
......@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
#m = self.layer_norm_m(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
m = self.global_attention(m=m, mask=mask)
m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......
......@@ -20,7 +20,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module):
......@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def forward(self,
def _forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
ln = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
b = self.linear_2(m) * mask
a = self.linear_1(ln)
a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
......@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm)
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
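Splitting _forward out of forward here mirrors the DistogramHead change above: when fp16 autocast is active, the outer-product reduction runs in fp32. A plausible motivation (not stated in the diff) is that sums over many MSA rows can exceed fp16's representable range:

import torch

# fp16 tops out around 65504, so accumulating large reductions in half
# precision can overflow to inf, while fp32 represents the same value easily.
print(torch.finfo(torch.float16).max)               # 65504.0
print(torch.tensor(70000.0, dtype=torch.float16))   # tensor(inf, dtype=torch.float16)
print(torch.tensor(70000.0, dtype=torch.float32))   # tensor(70000.)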
......@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module):
......@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
z = self.linear_2(z)
z = z * mask
return z
......@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
This diff is collapsed.
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
This diff is collapsed.
......@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
try:
# openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
......
......@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations
......@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
try:
# openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
......
This diff is collapsed.
This diff is collapsed.