Unverified Commit 9c0e7519 authored by Fazzie-Maqianli, committed by GitHub

Multimer (#57)

* add import_weight

* add structure module
parent ea7a6584
@@ -17,6 +17,7 @@ from functools import partial
 import torch
 import torch.nn as nn
+from fastfold.data import data_transforms_multimer
 from fastfold.utils.feats import (
     pseudo_beta_fn,
     build_extra_msa_feat,
@@ -27,8 +28,7 @@ from fastfold.utils.feats import (
 from fastfold.model.nn.embedders import (
     InputEmbedder,
     RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
+    TemplateEmbedder,
     ExtraMSAEmbedder,
 )
 from fastfold.model.nn.embedders_multimer import TemplateEmbedderMultimer, InputEmbedderMultimer
@@ -36,10 +36,6 @@ from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
 from fastfold.model.nn.heads import AuxiliaryHeads
 import fastfold.common.residue_constants as residue_constants
 from fastfold.model.nn.structure_module import StructureModule
-from fastfold.model.nn.template import (
-    TemplatePairStack,
-    TemplatePointwiseAttention,
-)
 from fastfold.model.loss import (
     compute_plddt,
 )
@@ -81,24 +77,13 @@ class AlphaFold(nn.Module):
         self.input_embedder = InputEmbedder(
             **config["input_embedder"],
         )
-        self.template_angle_embedder = TemplateAngleEmbedder(
-            **template_config["template_angle_embedder"],
-        )
-        self.template_pair_embedder = TemplatePairEmbedder(
-            **template_config["template_pair_embedder"],
-        )
-        self.template_pair_stack = TemplatePairStack(
-            **template_config["template_pair_stack"],
-        )
-        self.template_pointwise_att = TemplatePointwiseAttention(
-            **template_config["template_pointwise_attention"],
-        )
+        self.template_embedder = TemplateEmbedder(
+            template_config,
+        )
         self.recycling_embedder = RecyclingEmbedder(
             **config["recycling_embedder"],
         )
         self.extra_msa_embedder = ExtraMSAEmbedder(
             **extra_msa_config["extra_msa_embedder"],
         )
@@ -210,11 +195,15 @@ class AlphaFold(nn.Module):
         # m: [*, S_c, N, C_m]
         # z: [*, N, N, C_z]
-        m, z = self.input_embedder(
-            feats["target_feat"],
-            feats["residue_index"],
-            feats["msa_feat"],
-        )
+        m, z = (
+            self.input_embedder(
+                feats["target_feat"],
+                feats["residue_index"],
+                feats["msa_feat"],
+            )
+            if not self.globals.is_multimer
+            else self.input_embedder(feats)
+        )

         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
@@ -236,9 +225,8 @@ class AlphaFold(nn.Module):
                 requires_grad=False,
             )

-        x_prev = pseudo_beta_fn(
-            feats["aatype"], x_prev, None
-        ).to(dtype=z.dtype)
+        x_prev, _ = pseudo_beta_fn(feats["aatype"], x_prev, None)
+        x_prev = x_prev.to(dtype=z.dtype)

         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
@@ -270,40 +258,72 @@ class AlphaFold(nn.Module):
             template_feats = {
                 k: v for k, v in feats.items() if k.startswith("template_")
             }
-            template_embeds = self.embed_templates(
-                template_feats,
-                z,
-                pair_mask.to(dtype=z.dtype),
-                no_batch_dims,
-            )
+            if self.globals.is_multimer:
+                asym_id = feats["asym_id"]
+                multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    chunk_size=self.globals.chunk_size,
+                    multichain_mask_2d=multichain_mask_2d,
+                )
+                feats["template_torsion_angles_mask"] = (
+                    template_embeds["template_mask"]
+                )
+            else:
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    self.globals.chunk_size
+                )

             # [*, N, N, C_z]
             z = z + template_embeds["template_pair_embedding"]

-            if self.config.template.embed_angles:
+            if(
+                self.config.template.embed_angles or
+                (self.globals.is_multimer and self.config.template.enabled)
+            ):
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
-                    [m, template_embeds["template_angle_embedding"]],
+                    [m, template_embeds["template_single_embedding"]],
                     dim=-3
                 )

                 # [*, S, N]
-                torsion_angles_mask = feats["template_torsion_angles_mask"]
-                msa_mask = torch.cat(
-                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
-                    dim=-2
-                )
+                if(not self.globals.is_multimer):
+                    torsion_angles_mask = feats["template_torsion_angles_mask"]
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], torsion_angles_mask[..., 2]],
+                        dim=-2
+                    )
+                else:
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], template_embeds["template_mask"]],
+                        dim=-2,
+                    )

         # Embed extra MSA features + merge with pairwise embeddings
         if self.config.extra_msa.enabled:
+            if(self.globals.is_multimer):
+                extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
+            else:
+                extra_msa_fn = build_extra_msa_feat
+
             # [*, S_e, N, C_e]
-            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
+            extra_msa_feat = extra_msa_fn(feats)
+            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

             # [*, N, N, C_z]
             z = self.extra_msa_stack(
-                a,
+                extra_msa_feat,
                 z,
-                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
+                msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
                 chunk_size=self.globals.chunk_size,
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
@@ -353,14 +373,14 @@ class AlphaFold(nn.Module):
         return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = None
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None
         for b in self.extra_msa_stack.blocks:
             b.ckpt = False

     def _enable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = (
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = (
             self.config.template.template_pair_stack.blocks_per_ckpt
         )
         self.evoformer.blocks_per_ckpt = (
......
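Note on the multimer template branch above: the per-chain pair mask is built by broadcasting `asym_id` against itself, so only intra-chain pairs are passed to the template embedder. A minimal, self-contained sketch of that one masking step (the tensor values here are illustrative, not FastFold data):

```python
import torch

# Toy asym_id for 5 residues spread over two chains: [N_res]
asym_id = torch.tensor([0, 0, 0, 1, 1])

# Broadcasting [N, 1] against [1, N] gives an [N, N] boolean mask that is
# True only where both residues belong to the same chain.
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]

print(multichain_mask_2d.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]])
```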
@@ -17,9 +17,20 @@ import torch
 import torch.nn as nn
 from typing import Tuple, Dict
+from functools import partial

+from fastfold.utils import all_atom_multimer
+from fastfold.utils.feats import (
+    build_template_angle_feat,
+    build_template_pair_feat,
+)
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot
+from fastfold.model.nn.template import (
+    TemplatePairStack,
+    TemplatePointwiseAttention,
+)
+from fastfold.utils import geometry
+from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap


 class InputEmbedder(nn.Module):
     """
@@ -221,6 +232,97 @@ class RecyclingEmbedder(nn.Module):
         return m_update, z_update


+class TemplateEmbedder(nn.Module):
+    def __init__(self, config):
+        super(TemplateEmbedder, self).__init__()
+
+        self.config = config
+        self.template_angle_embedder = TemplateAngleEmbedder(
+            **config["template_angle_embedder"],
+        )
+        self.template_pair_embedder = TemplatePairEmbedder(
+            **config["template_pair_embedder"],
+        )
+        self.template_pair_stack = TemplatePairStack(
+            **config["template_pair_stack"],
+        )
+        self.template_pointwise_att = TemplatePointwiseAttention(
+            **config["template_pointwise_attention"],
+        )
+
+    def forward(self,
+        batch,
+        z,
+        pair_mask,
+        templ_dim,
+        chunk_size,
+        _mask_trans=True
+    ):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[templ_dim]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            single_template_embeds = {}
+            if self.config.embed_angles:
+                template_angle_feat = build_template_angle_feat(
+                    single_template_feats,
+                )
+
+                # [*, S_t, N, C_m]
+                a = self.template_angle_embedder(template_angle_feat)
+
+                single_template_embeds["angle"] = a
+
+            # [*, S_t, N, N, C_t]
+            t = build_template_pair_feat(
+                single_template_feats,
+                use_unit_vector=self.config.use_unit_vector,
+                inf=self.config.inf,
+                eps=self.config.eps,
+                **self.config.distogram,
+            ).to(z.dtype)
+            t = self.template_pair_embedder(t)
+
+            single_template_embeds.update({"pair": t})
+
+            template_embeds.append(single_template_embeds)
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["pair"],
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        # [*, N, N, C_z]
+        t = self.template_pointwise_att(
+            t,
+            z,
+            template_mask=batch["template_mask"].to(dtype=z.dtype),
+            chunk_size=chunk_size,
+        )
+        t = t * (torch.sum(batch["template_mask"]) > 0)
+
+        ret = {}
+        if self.config.embed_angles:
+            ret["template_single_embedding"] = template_embeds["angle"]
+
+        ret.update({"template_pair_embedding": t})
+
+        return ret
+
+
 class TemplateAngleEmbedder(nn.Module):
     """
......
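The new `TemplateEmbedder.forward` above slices one template at a time out of a dict of batched template features ("a poor man's vmap"). A rough standalone sketch of that slicing pattern, with a local stand-in for `tensor_tree_map` (the real helper lives in `fastfold.utils.tensor_utils` and also handles nested dicts; tensor shapes here are illustrative):

```python
import torch

def map_dict(fn, d):
    # Illustrative stand-in for tensor_tree_map, restricted to flat dicts.
    return {k: fn(v) for k, v in d.items()}

batch = {
    "template_aatype": torch.randint(0, 20, (4, 16)),   # [S_t, N]
    "template_pseudo_beta": torch.randn(4, 16, 3),       # [S_t, N, 3]
}
templ_dim = 0

n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
    # A one-element index keeps the template dimension (size 1), so each
    # sliced tensor still looks like [1, N, ...] to downstream code.
    idx = batch["template_aatype"].new_tensor([i])
    single = map_dict(lambda t: torch.index_select(t, templ_dim, idx), batch)
    # ... embed `single` here and append the per-template results ...
```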
@@ -84,7 +84,6 @@ class MSATransition(nn.Module):
             no_batch_dims=len(m.shape[:-2]),
         )

     def forward(
         self,
         m: torch.Tensor,
@@ -101,10 +100,12 @@ class MSATransition(nn.Module):
             m:
                 [*, N_seq, N_res, C_m] MSA activation update
         """
         # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
         if mask is None:
             mask = m.new_ones(m.shape[:-1])

+        # [*, N_seq, N_res, 1]
         mask = mask.unsqueeze(-1)

         m = self.layer_norm(m)
@@ -132,9 +133,10 @@ class EvoformerBlockCore(nn.Module):
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
+        is_multimer: bool = False,
     ):
         super(EvoformerBlockCore, self).__init__()
+        self.is_multimer = is_multimer
         self.msa_transition = MSATransition(
             c_m=c_m,
             n=transition_n,
@@ -261,6 +263,12 @@ class EvoformerBlock(nn.Module):
             eps=eps,
         )

+        self.outer_product_mean = OuterProductMean(
+            c_m,
+            c_z,
+            c_hidden_opm,
+        )
+
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
......
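The comment added in the `MSATransition` hunk just documents the broadcast: the `[*, N_seq, N_res]` mask gets a trailing singleton dimension so it can multiply the `[*, N_seq, N_res, C_m]` activations channel-wise. A tiny sketch of that shape manipulation, with made-up sizes:

```python
import torch

m = torch.randn(8, 32, 64)          # [N_seq, N_res, C_m] toy MSA activations
mask = m.new_ones(m.shape[:-1])     # default mask when none is supplied: [N_seq, N_res]

mask = mask.unsqueeze(-1)           # [N_seq, N_res] -> [N_seq, N_res, 1]
masked = m * mask                   # broadcasts over the channel dimension

print(masked.shape)                 # torch.Size([8, 32, 64])
```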
@@ -16,7 +16,7 @@
 import math
 import torch
 import torch.nn as nn
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from fastfold.common.residue_constants import (
@@ -73,7 +73,9 @@ class AngleResnet(nn.Module):
     Implements Algorithm 20, lines 11-14
     """

-    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
+    def __init__(
+        self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float
+    ):
         """
         Args:
             c_in:
@@ -145,7 +147,7 @@ class AngleResnet(nn.Module):
         unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
-                torch.sum(s ** 2, dim=-1, keepdim=True),
+                torch.sum(s**2, dim=-1, keepdim=True),
                 min=self.eps,
             )
         )
@@ -153,6 +155,7 @@ class AngleResnet(nn.Module):
         return unnormalized_s, s

+
 class PointProjection(nn.Module):
     def __init__(
         self,
@@ -491,7 +494,7 @@ class BackboneUpdate(nn.Module):
     Implements part of Algorithm 23.
     """

-    def __init__(self, c_s):
+    def __init__(self, c_s: int):
         """
         Args:
             c_s:
@@ -517,7 +520,7 @@ class BackboneUpdate(nn.Module):

 class StructureModuleTransitionLayer(nn.Module):
-    def __init__(self, c):
+    def __init__(self, c: int):
         super(StructureModuleTransitionLayer, self).__init__()

         self.c = c
@@ -528,7 +531,7 @@ class StructureModuleTransitionLayer(nn.Module):
         self.relu = nn.ReLU()

-    def forward(self, s):
+    def forward(self, s: torch.Tensor):
         s_initial = s
         s = self.linear_1(s)
         s = self.relu(s)
@@ -542,7 +545,7 @@ class StructureModuleTransitionLayer(nn.Module):

 class StructureModuleTransition(nn.Module):
-    def __init__(self, c, num_layers, dropout_rate):
+    def __init__(self, c: int, num_layers: int, dropout_rate: float):
         super(StructureModuleTransition, self).__init__()

         self.c = c
@@ -557,7 +560,7 @@ class StructureModuleTransition(nn.Module):
         self.dropout = nn.Dropout(self.dropout_rate)
         self.layer_norm = LayerNorm(self.c)

-    def forward(self, s):
+    def forward(self, s: torch.Tensor) -> torch.Tensor:
         for l in self.layers:
             s = l(s)
@@ -570,22 +573,22 @@ class StructureModuleTransition(nn.Module):
 class StructureModule(nn.Module):
     def __init__(
         self,
-        c_s,
-        c_z,
-        c_ipa,
-        c_resnet,
-        no_heads_ipa,
-        no_qk_points,
-        no_v_points,
-        dropout_rate,
-        no_blocks,
-        no_transition_layers,
-        no_resnet_blocks,
-        no_angles,
-        trans_scale_factor,
-        epsilon,
-        inf,
-        is_multimer=False,
+        c_s: int,
+        c_z: int,
+        c_ipa: int,
+        c_resnet: int,
+        no_heads_ipa: int,
+        no_qk_points: int,
+        no_v_points: int,
+        dropout_rate: float,
+        no_blocks: int,
+        no_transition_layers: int,
+        no_resnet_blocks: int,
+        no_angles: int,
+        trans_scale_factor: float,
+        epsilon: float,
+        inf: float,
+        is_multimer: bool = False,
         **kwargs,
     ):
         """
@@ -621,6 +624,8 @@ class StructureModule(nn.Module):
                 Small number used in angle resnet normalization
             inf:
                 Large number used for attention masking
+            is_multimer:
+                whether running under multimer mode
         """
         super(StructureModule, self).__init__()
@@ -673,6 +678,9 @@ class StructureModule(nn.Module):
             self.dropout_rate,
         )

-        self.bb_update = BackboneUpdate(self.c_s)
+        if is_multimer:
+            self.bb_update = QuatRigid(self.c_s, full_quat=False)
+        else:
+            self.bb_update = BackboneUpdate(self.c_s)

         self.angle_resnet = AngleResnet(
@@ -683,13 +691,13 @@ class StructureModule(nn.Module):
             self.epsilon,
         )

-    def forward(
+    def _forward_monomer(
         self,
-        s,
-        z,
-        aatype,
-        mask=None,
-    ):
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
         """
         Args:
             s:
@@ -785,7 +793,103 @@ class StructureModule(nn.Module):
         return outputs

-    def _init_residue_constants(self, float_dtype, device):
+    def _forward_multimer(
+        self,
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
+        if mask is None:
+            # [*, N]
+            mask = s.new_ones(s.shape[:-1])
+
+        # [*, N, C_s]
+        s = self.layer_norm_s(s)
+
+        # [*, N, N, C_z]
+        z = self.layer_norm_z(z)
+
+        # [*, N, C_s]
+        s_initial = s
+        s = self.linear_in(s)
+
+        # [*, N]
+        rigids = Rigid3Array.identity(
+            s.shape[:-1],
+            s.device,
+        )
+        outputs = []
+        for i in range(self.no_blocks):
+            # [*, N, C_s]
+            s = s + self.ipa(s, z, rigids, mask)
+            s = self.ipa_dropout(s)
+            s = self.layer_norm_ipa(s)
+            s = self.transition(s)
+
+            # [*, N]
+            rigids = rigids @ self.bb_update(s)
+
+            # [*, N, 7, 2]
+            unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+            all_frames_to_global = self.torsion_angles_to_frames(
+                rigids.scale_translation(self.trans_scale_factor),
+                angles,
+                aatype,
+            )
+
+            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
+                all_frames_to_global,
+                aatype,
+            )
+
+            preds = {
+                "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
+                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+                "unnormalized_angles": unnormalized_angles,
+                "angles": angles,
+                "positions": pred_xyz.to_tensor(),
+            }
+
+            outputs.append(preds)
+            if i < (self.no_blocks - 1):
+                rigids = rigids.stop_rot_gradient()
+
+        outputs = dict_multimap(torch.stack, outputs)
+        outputs["single"] = s
+
+        return outputs
+
+    def forward(
+        self,
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+            z:
+                [*, N_res, N_res, C_z] pair representation
+            aatype:
+                [*, N_res] amino acid indices
+            mask:
+                Optional [*, N_res] sequence mask
+        Returns:
+            A dictionary of outputs
+        """
+        if self.is_multimer:
+            outputs = self._forward_multimer(s, z, aatype, mask)
+        else:
+            outputs = self._forward_monomer(s, z, aatype, mask)
+
+        return outputs
+
+    def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device):
         if self.default_frames is None:
             self.default_frames = torch.tensor(
                 restype_rigid_group_default_frame,
@@ -814,17 +918,24 @@ class StructureModule(nn.Module):
                 requires_grad=False,
             )

-    def torsion_angles_to_frames(self, r, alpha, f):
+    def torsion_angles_to_frames(
+        self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f
+    ):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
         return torsion_angles_to_frames(r, alpha, f, self.default_frames)

     def frames_and_literature_positions_to_atom14_pos(
-        self, r, f  # [*, N, 8]  # [*, N]
+        self, r: Union[Rigid, Rigid3Array], f  # [*, N, 8]  # [*, N]
     ):
         # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        if type(r) == Rigid:
+            self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        elif type(r) == Rigid3Array:
+            self._init_residue_constants(r.dtype, r.device)
+        else:
+            raise ValueError("Unknown rigid type")
         return frames_and_literature_positions_to_atom14_pos(
             r,
             f,
......
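Both structure-module passes collect one dict of predictions per block and then merge the list with `dict_multimap(torch.stack, outputs)`. A self-contained sketch of what that aggregation produces, using a local helper instead of the real `fastfold.utils.tensor_utils.dict_multimap` (which also handles nested dicts; shapes are illustrative):

```python
import torch

def stack_dicts(dicts):
    # Minimal stand-in: apply torch.stack key-wise over a list of flat dicts.
    return {k: torch.stack([d[k] for d in dicts]) for k in dicts[0]}

no_blocks, n_res = 8, 16
outputs = [
    {"angles": torch.randn(n_res, 7, 2), "positions": torch.randn(n_res, 14, 3)}
    for _ in range(no_blocks)
]

merged = stack_dicts(outputs)
print(merged["angles"].shape)     # torch.Size([8, 16, 7, 2])
print(merged["positions"].shape)  # torch.Size([8, 16, 14, 3])
```

The dispatch in `forward` then only has to pick `_forward_multimer` or `_forward_monomer`; both return the same dict layout, so downstream heads are unchanged.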
@@ -18,10 +18,12 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Dict
+from typing import Any, Dict, Optional, Tuple, Union

 from fastfold.common import protein
 import fastfold.common.residue_constants as rc
+from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from fastfold.utils.geometry.rotation_matrix import Rot3Array
 from fastfold.utils.rigid_utils import Rotation, Rigid
 from fastfold.utils.tensor_utils import (
     batched_gather,
@@ -36,7 +38,7 @@ def dgram_from_positions(
     max_bin: float = 50.75,
     no_bins: float = 39,
     inf: float = 1e8,
-):
+) -> torch.Tensor:
     dgram = torch.sum(
         (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
     )
@@ -46,8 +48,9 @@ def dgram_from_positions(
     return dgram


-def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+def pseudo_beta_fn(
+    aatype, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     is_gly = aatype == rc.restype_order["G"]
     ca_idx = rc.atom_order["CA"]
     cb_idx = rc.atom_order["CB"]
@@ -65,10 +68,10 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
         )
         return pseudo_beta, pseudo_beta_mask
     else:
-        return pseudo_beta
+        return pseudo_beta, None


-def atom14_to_atom37(atom14, batch):
+def atom14_to_atom37(atom14, batch: Dict[str, Any]):
     atom37_data = batched_gather(
         atom14,
         batch["residx_atom37_to_atom14"],
@@ -81,19 +84,15 @@ def atom14_to_atom37(atom14, batch):
     return atom37_data


-def build_template_angle_feat(template_feats):
+def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor:
     template_aatype = template_feats["template_aatype"]
     torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
-    alt_torsion_angles_sin_cos = template_feats[
-        "template_alt_torsion_angles_sin_cos"
-    ]
+    alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
     torsion_angles_mask = template_feats["template_torsion_angles_mask"]
     template_angle_feat = torch.cat(
         [
             nn.functional.one_hot(template_aatype, 22),
-            torsion_angles_sin_cos.reshape(
-                *torsion_angles_sin_cos.shape[:-2], 14
-            ),
+            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
             alt_torsion_angles_sin_cos.reshape(
                 *alt_torsion_angles_sin_cos.shape[:-2], 14
             ),
@@ -106,22 +105,20 @@ def build_template_angle_feat(template_feats):

 def build_template_pair_feat(
-    batch,
-    min_bin, max_bin, no_bins,
-    use_unit_vector=False,
-    eps=1e-20, inf=1e8
+    batch: Dict[str, Any],
+    min_bin: float,
+    max_bin: float,
+    no_bins: int,
+    use_unit_vector: bool = False,
+    eps: float = 1e-20,
+    inf: float = 1e8,
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

     # Compute distogram (this seems to differ slightly from Alg. 5)
     tpb = batch["template_pseudo_beta"]
-    dgram = torch.sum(
-        (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
-    )
-    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
-    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+    dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)

     to_concat = [dgram, template_mask_2d[..., None]]
@@ -137,9 +134,7 @@ def build_template_pair_feat(
         )
     )
     to_concat.append(
-        aatype_one_hot[..., None, :].expand(
-            *aatype_one_hot.shape[:-2], -1, n_res, -1
-        )
+        aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
     )

     n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
@@ -152,19 +147,17 @@ def build_template_pair_feat(
     points = rigids.get_trans()[..., None, :, :]
     rigid_vec = rigids[..., None].invert_apply(points)

-    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
+    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

     t_aa_masks = batch["template_all_atom_mask"]
-    template_mask = (
-        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
-    )
+    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

     inv_distance_scalar = inv_distance_scalar * template_mask_2d
     unit_vector = rigid_vec * inv_distance_scalar[..., None]

-    if(not use_unit_vector):
-        unit_vector = unit_vector * 0.
+    if not use_unit_vector:
+        unit_vector = unit_vector * 0.0

     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
@@ -175,7 +168,7 @@ def build_template_pair_feat(
     return act


-def build_extra_msa_feat(batch):
+def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor:
     msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
     msa_feat = [
         msa_1hot,
@@ -186,11 +179,11 @@ def build_extra_msa_feat(batch):

 def torsion_angles_to_frames(
-    r: Rigid,
+    r: Union[Rigid3Array, Rigid],
     alpha: torch.Tensor,
     aatype: torch.Tensor,
     rrgdf: torch.Tensor,
-):
+) -> Union[Rigid, Rigid3Array]:
     # [*, N, 8, 4, 4]
     default_4x4 = rrgdf[aatype, ...]
@@ -203,9 +196,7 @@ def torsion_angles_to_frames(
     bb_rot[..., 1] = 1

     # [*, N, 8, 2]
-    alpha = torch.cat(
-        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
-    )
+    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

     # [*, N, 8, 3, 3]
     # Produces rotation matrices of the form:
@@ -216,16 +207,26 @@ def torsion_angles_to_frames(
     # ]
     # This follows the original code rather than the supplement, which uses
     # different indices.
-    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+    if type(r) == Rigid3Array:
+        all_rots = alpha.new_zeros(default_r.shape + (3, 3))
+    elif type(r) == Rigid:
+        all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")
+
     all_rots[..., 0, 0] = 1
     all_rots[..., 1, 1] = alpha[..., 1]
     all_rots[..., 1, 2] = -alpha[..., 0]
     all_rots[..., 2, 1:] = alpha

-    all_rots = Rigid(Rotation(rot_mats=all_rots), None)
-    all_frames = default_r.compose(all_rots)
+    if type(r) == Rigid3Array:
+        all_rots = Rot3Array.from_array(all_rots)
+        all_frames = default_r.compose_rotation(all_rots)
+    elif type(r) == Rigid:
+        all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+        all_frames = default_r.compose(all_rots)
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")

     chi2_frame_to_frame = all_frames[..., 5]
     chi3_frame_to_frame = all_frames[..., 6]
@@ -236,6 +237,17 @@ def torsion_angles_to_frames(
     chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
     chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

+    if type(all_frames) == Rigid3Array:
+        all_frames_to_bb = Rigid3Array.cat(
+            [
+                all_frames[..., :5],
+                chi2_frame_to_bb.unsqueeze(-1),
+                chi3_frame_to_bb.unsqueeze(-1),
+                chi4_frame_to_bb.unsqueeze(-1),
+            ],
+            dim=-1,
+        )
+    elif type(all_frames) == Rigid:
         all_frames_to_bb = Rigid.cat(
             [
                 all_frames[..., :5],
@@ -252,13 +264,13 @@ def torsion_angles_to_frames(

 def frames_and_literature_positions_to_atom14_pos(
-    r: Rigid,
+    r: Union[Rigid3Array, Rigid],
     aatype: torch.Tensor,
-    default_frames,
-    group_idx,
-    atom_mask,
-    lit_positions,
-):
+    default_frames: torch.Tensor,
+    group_idx: torch.Tensor,
+    atom_mask: torch.Tensor,
+    lit_positions: torch.Tensor,
+) -> torch.Tensor:
     # [*, N, 14, 4, 4]
     default_4x4 = default_frames[aatype, ...]
@@ -266,21 +278,30 @@ def frames_and_literature_positions_to_atom14_pos(
     group_mask = group_idx[aatype, ...]

     # [*, N, 14, 8]
-    group_mask = nn.functional.one_hot(
-        group_mask,
-        num_classes=default_frames.shape[-3],
-    )
+    if type(r) == Rigid3Array:
+        group_mask = nn.functional.one_hot(
+            group_mask.long(),
+            num_classes=default_frames.shape[-3],
+        )
+    elif type(r) == Rigid:
+        group_mask = nn.functional.one_hot(
+            group_mask,
+            num_classes=default_frames.shape[-3],
+        )
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")

     # [*, N, 14, 8]
     t_atoms_to_global = r[..., None, :] * group_mask

     # [*, N, 14]
-    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
-        lambda x: torch.sum(x, dim=-1)
-    )
+    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

     # [*, N, 14, 1]
-    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+    if type(r) == Rigid:
+        atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+    elif type(r) == Rigid3Array:
+        atom_mask = atom_mask[aatype, ...]

     # [*, N, 14, 3]
     lit_positions = lit_positions[aatype, ...]
......
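In the hunk above, `build_template_pair_feat` drops its inlined binning code and calls `dgram_from_positions` instead. The binning itself is simple: squared pairwise distances are compared against squared bin edges to produce a one-hot distogram, with the last bin open-ended. A self-contained sketch mirroring the logic shown in the removed lines (no FastFold import; sizes are illustrative):

```python
import torch

def one_hot_dgram(pos, min_bin=3.25, max_bin=50.75, no_bins=39, inf=1e8):
    # [N, N, 1] squared pairwise distances
    d2 = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # Squared bin edges; the last upper edge is inf so the final bin is open
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    # One-hot bin membership: exactly one bin satisfies lower < d2 < upper
    return ((d2 > lower) * (d2 < upper)).type(pos.dtype)

pos = torch.randn(16, 3) * 10.0   # toy pseudo-beta coordinates, [N, 3]
dgram = one_hot_dgram(pos)
print(dgram.shape)                # torch.Size([16, 16, 39])
```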
@@ -39,6 +39,12 @@ class ParamType(Enum):
     LinearWeightOPM = partial(
         lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
     )
+    LinearWeightMultimer = partial(
+        lambda w: w.unsqueeze(-1)
+        if len(w.shape) == 1
+        else w.reshape(w.shape[0], -1).transpose(-1, -2)
+    )
+    LinearBiasMultimer = partial(lambda w: w.reshape(-1))
     Other = partial(lambda w: w)

     def __init__(self, fn):
@@ -121,29 +127,30 @@ def assign(translation_dict, orig_weights):
                 print(weights[0].shape)
                 raise

-def import_jax_weights_(model, npz_path, version="model_1"):
-    data = np.load(npz_path)
-
+def get_translation_dict(model, is_multimer: bool = False):
     #######################
     # Some templates
     #######################
     LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
     LinearBias = lambda l: (Param(l))
     LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
     LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
     LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+    LinearWeightMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearWeightMultimer)
+    )
+    LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer))
     LinearParams = lambda l: {
         "weights": LinearWeight(l.weight),
         "bias": LinearBias(l.bias),
     }
+    LinearParamsMultimer = lambda l: {
+        "weights": LinearWeightMultimer(l.weight),
+        "bias": LinearBiasMultimer(l.bias),
+    }
     LayerNormParams = lambda l: {
         "scale": Param(l.weight),
         "offset": Param(l.bias),
@@ -239,7 +246,43 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         "q_scalar": LinearParams(ipa.linear_q),
         "kv_scalar": LinearParams(ipa.linear_kv),
         "q_point_local": LinearParams(ipa.linear_q_points),
+        # New style IPA param
+        # "q_point_local": LinearParams(ipa.linear_q_points.linear),
         "kv_point_local": LinearParams(ipa.linear_kv_points),
+        # New style IPA param
+        # "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
+        "trainable_point_weights": Param(
+            param=ipa.head_weights, param_type=ParamType.Other
+        ),
+        "attention_2d": LinearParams(ipa.linear_b),
+        "output_projection": LinearParams(ipa.linear_out),
+    }
+    PointProjectionParams = lambda pp: {
+        "point_projection": LinearParamsMultimer(
+            pp.linear,
+        ),
+    }
+    IPAParamsMultimer = lambda ipa: {
+        "q_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_q.weight,
+            ),
+        },
+        "k_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_k.weight,
+            ),
+        },
+        "v_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_v.weight,
+            ),
+        },
+        "q_point_projection": PointProjectionParams(ipa.linear_q_points),
+        "k_point_projection": PointProjectionParams(ipa.linear_k_points),
+        "v_point_projection": PointProjectionParams(ipa.linear_v_points),
         "trainable_point_weights": Param(
             param=ipa.head_weights, param_type=ParamType.Other
         ),
@@ -278,31 +321,26 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         msa_col_att_params = MSAColAttParams(b.msa_att_col)

         d = {
-            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
-                b.msa_att_row
-            ),
+            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
             col_att_name: msa_col_att_params,
             "msa_transition": MSATransitionParams(b.core.msa_transition),
-            "outer_product_mean":
-                OuterProductMeanParams(b.core.outer_product_mean),
-            "triangle_multiplication_outgoing":
-                TriMulOutParams(b.core.tri_mul_out),
-            "triangle_multiplication_incoming":
-                TriMulInParams(b.core.tri_mul_in),
-            "triangle_attention_starting_node":
-                TriAttParams(b.core.tri_att_start),
-            "triangle_attention_ending_node":
-                TriAttParams(b.core.tri_att_end),
-            "pair_transition":
-                PairTransitionParams(b.core.pair_transition),
+            "outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
+            "triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
+            "triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
+            "triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
+            "triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
+            "pair_transition": PairTransitionParams(b.core.pair_transition),
         }

         return d

     ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

-    FoldIterationParams = lambda sm: {
-        "invariant_point_attention": IPAParams(sm.ipa),
+    def FoldIterationParams(sm):
+        d = {
+            "invariant_point_attention": IPAParamsMultimer(sm.ipa)
+            if is_multimer
+            else IPAParams(sm.ipa),
             "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
             "transition": LinearParams(sm.transition.layers[0].linear_1),
             "transition_1": LinearParams(sm.transition.layers[0].linear_2),
@@ -320,14 +358,17 @@ def import_jax_weights_(model, npz_path, version="model_1"):
             },
         }

+        if is_multimer:
+            d.pop("affine_update")
+            d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}
+
+        return d
+
     ############################
     # translations dict overflow
     ############################

-    tps_blocks = model.template_pair_stack.blocks
-    tps_blocks_params = stacked(
-        [TemplatePairBlockParams(b) for b in tps_blocks]
-    )
+    tps_blocks = model.template_embedder.template_pair_stack.blocks
+    tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])

     ems_blocks = model.extra_msa_stack.blocks
     ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
@@ -335,6 +376,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     evo_blocks = model.evoformer.blocks
     evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

+    if not is_multimer:
         translations = {
             "evoformer": {
                 "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
@@ -348,32 +390,30 @@ def import_jax_weights_(model, npz_path, version="model_1"):
                 "prev_pair_norm": LayerNormParams(
                     model.recycling_embedder.layer_norm_z
                 ),
-                "pair_activiations": LinearParams(
-                    model.input_embedder.linear_relpos
-                ),
+                "pair_activiations": LinearParams(model.input_embedder.linear_relpos),
                 "template_embedding": {
                     "single_template_embedding": {
                         "embedding2d": LinearParams(
-                            model.template_pair_embedder.linear
+                            model.template_embedder.template_pair_embedder.linear
                         ),
                         "template_pair_stack": {
                             "__layer_stack_no_state": tps_blocks_params,
                         },
                         "output_layer_norm": LayerNormParams(
-                            model.template_pair_stack.layer_norm
+                            model.template_embedder.template_pair_stack.layer_norm
                         ),
                     },
-                    "attention": AttentionParams(model.template_pointwise_att.mha),
+                    "attention": AttentionParams(
+                        model.template_embedder.template_pointwise_att.mha
+                    ),
                 },
-                "extra_msa_activations": LinearParams(
-                    model.extra_msa_embedder.linear
-                ),
+                "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
                 "extra_msa_stack": ems_blocks_params,
                 "template_single_embedding": LinearParams(
-                    model.template_angle_embedder.linear_1
+                    model.template_embedder.template_angle_embedder.linear_1
                 ),
                 "template_projection": LinearParams(
-                    model.template_angle_embedder.linear_2
+                    model.template_embedder.template_angle_embedder.linear_2
                 ),
                 "evoformer_iteration": evo_blocks_params,
                 "single_activations": LinearParams(model.evoformer.linear),
@@ -382,18 +422,106 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            "structure_module": {
                 "single_layer_norm": LayerNormParams(
                     model.structure_module.layer_norm_s
                 ),
-                "initial_projection": LinearParams(
-                    model.structure_module.linear_in
-                ),
-                "pair_layer_norm": LayerNormParams(
-                    model.structure_module.layer_norm_z
-                ),
+                "initial_projection": LinearParams(model.structure_module.linear_in),
+                "pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+    else:
+        temp_embedder = model.template_embedder
+
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "~_relative_encoding": {
+                    "position_activations": LinearParams(
+                        model.input_embedder.linear_relpos
+                    ),
+                },
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "query_embedding_norm": LayerNormParams(
+                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
+                        ),
+                        "template_pair_embedding_0": LinearParams(
+                            temp_embedder.template_pair_embedder.dgram_linear
+                        ),
+                        "template_pair_embedding_1": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
+                        ),
+                        "template_pair_embedding_2": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_1
+                        ),
+                        "template_pair_embedding_3": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_2
+                        ),
+                        "template_pair_embedding_4": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.x_linear
+                        ),
+                        "template_pair_embedding_5": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.y_linear
+                        ),
+                        "template_pair_embedding_6": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.z_linear
+                        ),
+                        "template_pair_embedding_7": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.backbone_mask_linear
+                        ),
+                        "template_pair_embedding_8": LinearParams(
+                            temp_embedder.template_pair_embedder.query_embedding_linear
+                        ),
+                        "template_embedding_iteration": tps_blocks_params,
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "output_linear": LinearParams(temp_embedder.linear_t),
+                },
+                "template_projection": LinearParams(
+                    temp_embedder.template_single_embedder.template_projector,
+                ),
+                "template_single_embedding": LinearParams(
+                    temp_embedder.template_single_embedder.template_single_embedder,
+                ),
+                "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(model.structure_module.linear_in),
+                "pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
                 "fold_iteration": FoldIterationParams(model.structure_module),
             },
             "predicted_lddt_head": {
-                "input_layer_norm": LayerNormParams(
-                    model.aux_heads.plddt.layer_norm
-                ),
+                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
                 "act_0": LinearParams(model.aux_heads.plddt.linear_1),
                 "act_1": LinearParams(model.aux_heads.plddt.linear_2),
                 "logits": LinearParams(model.aux_heads.plddt.linear_3),
@@ -402,15 +530,22 @@ def import_jax_weights_(model, npz_path, version="model_1"):
                 "half_logits": LinearParams(model.aux_heads.distogram.linear),
             },
             "experimentally_resolved_head": {
-                "logits": LinearParams(
-                    model.aux_heads.experimentally_resolved.linear
-                ),
+                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
             },
             "masked_msa_head": {
                 "logits": LinearParams(model.aux_heads.masked_msa.linear),
             },
         }

+    return translations
+
+
+def import_jax_weights_(model, npz_path, version="model_1"):
+    data = np.load(npz_path)
+
+    translations = get_translation_dict(model, is_multimer=("multimer" in version))
+
     no_templ = [
         "model_3",
         "model_4",
......
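The new `LinearWeightMultimer` param type flattens the multimer checkpoints' per-head weight layout into the 2-D `(out_features, in_features)` layout that `torch.nn.Linear` stores. A small sketch of the two cases handled by that lambda (the tensor shapes are illustrative, not taken from a real checkpoint):

```python
import torch

def linear_weight_multimer(w):
    # Mirrors the reshape applied by ParamType.LinearWeightMultimer above.
    if len(w.shape) == 1:
        return w.unsqueeze(-1)                           # [c_in] -> [c_in, 1]
    return w.reshape(w.shape[0], -1).transpose(-1, -2)   # [c_in, h, c] -> [h*c, c_in]

w = torch.randn(384, 12, 16)   # e.g. a per-head projection weight
print(linear_weight_multimer(w).shape)  # torch.Size([192, 384])

b = torch.randn(128)           # a 1-D weight becomes a single-column matrix
print(linear_weight_multimer(b).shape)  # torch.Size([128, 1])
```

As the final hunk shows, `import_jax_weights_` now simply builds the translation dict via `get_translation_dict`, choosing the multimer mapping whenever `"multimer"` appears in the `version` string.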
@@ -266,7 +266,7 @@ def inject_extraMsaBlock(model):

 def inject_templatePairBlock(model):
     with torch.no_grad():
-        target_module = model.template_pair_stack.blocks
+        target_module = model.template_embedder.template_pair_stack.blocks
         fastfold_blocks = nn.ModuleList()
         for block_id, ori_block in enumerate(target_module):
             c_t = ori_block.c_t
@@ -294,7 +294,7 @@ def inject_templatePairBlock(model):
             fastfold_block.eval()
             fastfold_blocks.append(fastfold_block)

-        model.template_pair_stack.blocks = fastfold_blocks
+        model.template_embedder.template_pair_stack.blocks = fastfold_blocks


 def inject_fastnn(model):
......
@@ -159,7 +159,7 @@ def main(args):
         print("Generating features...")
         local_alignment_dir = os.path.join(alignment_dir, tag)
         if global_is_multimer:
-            print("multimer mode")
+            print("running in multimer mode...")
             feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
         else:
             if (args.use_precomputed_alignments is None):
......