Unverified Commit 9c0e7519 authored by Fazzie-Maqianli, committed by GitHub

Multimer (#57)

* add import _weight

* add struct module
parent ea7a6584
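
The diff below threads a single is_multimer switch through the model: AlphaFold.forward and StructureModule.forward branch on it, the four template modules move into a TemplateEmbedder wrapper, and the JAX weight import gains a separate multimer translation table. As a quick orientation, here is a minimal sketch of that dispatch pattern; the class and layer names are illustrative stand-ins, not the FastFold API.

import torch
import torch.nn as nn

class ToyFold(nn.Module):
    """Illustrative stand-in for the monomer/multimer branching added in this PR."""
    def __init__(self, is_multimer: bool = False):
        super().__init__()
        self.is_multimer = is_multimer
        self.monomer_path = nn.Linear(8, 8)
        self.multimer_path = nn.Linear(8, 8)

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        # Mirrors StructureModule.forward: one flag selects one of two private paths.
        return self.multimer_path(s) if self.is_multimer else self.monomer_path(s)

print(ToyFold(is_multimer=True)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
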
......@@ -17,6 +17,7 @@ from functools import partial
import torch
import torch.nn as nn
from fastfold.data import data_transforms_multimer
from fastfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
......@@ -27,8 +28,7 @@ from fastfold.utils.feats import (
from fastfold.model.nn.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
TemplateEmbedder,
ExtraMSAEmbedder,
)
from fastfold.model.nn.embedders_multimer import TemplateEmbedderMultimer, InputEmbedderMultimer
......@@ -36,10 +36,6 @@ from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
from fastfold.model.nn.heads import AuxiliaryHeads
import fastfold.common.residue_constants as residue_constants
from fastfold.model.nn.structure_module import StructureModule
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.model.loss import (
compute_plddt,
)
......@@ -81,24 +77,13 @@ class AlphaFold(nn.Module):
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
self.template_embedder = TemplateEmbedder(
template_config,
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
......@@ -210,11 +195,15 @@ class AlphaFold(nn.Module):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
m, z = (
self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
if not self.globals.is_multimer
else self.input_embedder(feats)
)
# Initialize the recycling embeddings, if need be
if None in [m_1_prev, z_prev, x_prev]:
......@@ -236,9 +225,8 @@ class AlphaFold(nn.Module):
requires_grad=False,
)
x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
x_prev, _ = pseudo_beta_fn(feats["aatype"], x_prev, None)
x_prev = x_prev.to(dtype=z.dtype)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
......@@ -270,40 +258,72 @@ class AlphaFold(nn.Module):
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
if self.globals.is_multimer:
asym_id = feats["asym_id"]
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
[m, template_embeds["template_single_embedding"]],
dim=-3
)
# [*, S, N]
if(not self.globals.is_multimer):
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
else:
msa_mask = torch.cat(
[feats["msa_mask"], template_embeds["template_mask"]],
dim=-2,
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
if(self.globals.is_multimer):
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
else:
extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
......@@ -353,14 +373,14 @@ class AlphaFold(nn.Module):
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.template_embedder.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
self.template_embedder.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
......
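
In the multimer template path above, the pair mask is supplemented by a 2D same-chain mask built from the per-residue chain ids: asym_id[..., None] == asym_id[..., None, :]. A hedged, self-contained illustration of that comparison on toy chain ids:

import torch

asym_id = torch.tensor([0, 0, 1, 1, 1])                           # [N_res] chain ids
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]  # [N_res, N_res]
print(multichain_mask_2d.int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
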
......@@ -17,9 +17,20 @@ import torch
import torch.nn as nn
from typing import Tuple, Dict
from functools import partial
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
"""
......@@ -221,6 +232,97 @@ class RecyclingEmbedder(nn.Module):
return m_update, z_update
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
class TemplateAngleEmbedder(nn.Module):
"""
......
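
The per-template loop in TemplateEmbedder.forward above slices one template out of every tensor in the feature dict with torch.index_select along templ_dim; tensor_tree_map applies that lambda to each leaf tensor. A hedged illustration of the slicing on a toy feature dict with made-up shapes (a plain dict comprehension stands in for tensor_tree_map):

import torch

batch = {
    "template_aatype": torch.zeros(4, 16, dtype=torch.long),  # [S_t, N_res]
    "template_mask": torch.ones(4),                           # [S_t]
}
templ_dim = 0
idx = torch.tensor([1])  # select template 1; a 1-D index keeps the S_t dimension
single_template_feats = {
    k: torch.index_select(t, templ_dim, idx) for k, t in batch.items()
}
print(single_template_feats["template_aatype"].shape)  # torch.Size([1, 16])
print(single_template_feats["template_mask"].shape)    # torch.Size([1])
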
......@@ -84,7 +84,6 @@ class MSATransition(nn.Module):
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
......@@ -101,10 +100,12 @@ class MSATransition(nn.Module):
m:
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if mask is None:
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, 1]
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
......@@ -132,9 +133,10 @@ class EvoformerBlockCore(nn.Module):
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
is_multimer: bool = False,
):
super(EvoformerBlockCore, self).__init__()
self.is_multimer = is_multimer
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
......@@ -261,6 +263,12 @@ class EvoformerBlock(nn.Module):
eps=eps,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
......
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
from fastfold.common.residue_constants import (
......@@ -73,7 +73,9 @@ class AngleResnet(nn.Module):
Implements Algorithm 20, lines 11-14
"""
def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
def __init__(
self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float
):
"""
Args:
c_in:
......@@ -145,7 +147,7 @@ class AngleResnet(nn.Module):
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s ** 2, dim=-1, keepdim=True),
torch.sum(s**2, dim=-1, keepdim=True),
min=self.eps,
)
)
......@@ -153,6 +155,7 @@ class AngleResnet(nn.Module):
return unnormalized_s, s
class PointProjection(nn.Module):
def __init__(
self,
......@@ -491,7 +494,7 @@ class BackboneUpdate(nn.Module):
Implements part of Algorithm 23.
"""
def __init__(self, c_s):
def __init__(self, c_s: int):
"""
Args:
c_s:
......@@ -517,7 +520,7 @@ class BackboneUpdate(nn.Module):
class StructureModuleTransitionLayer(nn.Module):
def __init__(self, c):
def __init__(self, c: int):
super(StructureModuleTransitionLayer, self).__init__()
self.c = c
......@@ -528,7 +531,7 @@ class StructureModuleTransitionLayer(nn.Module):
self.relu = nn.ReLU()
def forward(self, s):
def forward(self, s: torch.Tensor):
s_initial = s
s = self.linear_1(s)
s = self.relu(s)
......@@ -542,7 +545,7 @@ class StructureModuleTransitionLayer(nn.Module):
class StructureModuleTransition(nn.Module):
def __init__(self, c, num_layers, dropout_rate):
def __init__(self, c: int, num_layers: int, dropout_rate: float):
super(StructureModuleTransition, self).__init__()
self.c = c
......@@ -557,7 +560,7 @@ class StructureModuleTransition(nn.Module):
self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = LayerNorm(self.c)
def forward(self, s):
def forward(self, s: torch.Tensor) -> torch.Tensor:
for l in self.layers:
s = l(s)
......@@ -570,22 +573,22 @@ class StructureModuleTransition(nn.Module):
class StructureModule(nn.Module):
def __init__(
self,
c_s,
c_z,
c_ipa,
c_resnet,
no_heads_ipa,
no_qk_points,
no_v_points,
dropout_rate,
no_blocks,
no_transition_layers,
no_resnet_blocks,
no_angles,
trans_scale_factor,
epsilon,
inf,
is_multimer=False,
c_s: int,
c_z: int,
c_ipa: int,
c_resnet: int,
no_heads_ipa: int,
no_qk_points: int,
no_v_points: int,
dropout_rate: float,
no_blocks: int,
no_transition_layers: int,
no_resnet_blocks: int,
no_angles: int,
trans_scale_factor: float,
epsilon: float,
inf: float,
is_multimer: bool = False,
**kwargs,
):
"""
......@@ -621,6 +624,8 @@ class StructureModule(nn.Module):
Small number used in angle resnet normalization
inf:
Large number used for attention masking
is_multimer:
Whether to run in multimer mode
"""
super(StructureModule, self).__init__()
......@@ -673,6 +678,9 @@ class StructureModule(nn.Module):
self.dropout_rate,
)
if is_multimer:
self.bb_update = QuatRigid(self.c_s, full_quat=False)
else:
self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
......@@ -683,13 +691,13 @@ class StructureModule(nn.Module):
self.epsilon,
)
def forward(
def _forward_monomer(
self,
s,
z,
aatype,
mask=None,
):
s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
"""
Args:
s:
......@@ -785,7 +793,103 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, float_dtype, device):
def _forward_multimer(
self,
s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
# [*, N, C_s]
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(z)
# [*, N, C_s]
s_initial = s
s = self.linear_in(s)
# [*, N]
rigids = Rigid3Array.identity(
s.shape[:-1],
s.device,
)
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
rigids = rigids @ self.bb_update(s)
# [*, N, 7, 2]
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
rigids.scale_translation(self.trans_scale_factor),
angles,
aatype,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
)
preds = {
"frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz.to_tensor(),
}
outputs.append(preds)
if i < (self.no_blocks - 1):
rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
if self.is_multimer:
outputs = self._forward_multimer(s, z, aatype, mask)
else:
outputs = self._forward_monomer(s, z, aatype, mask)
return outputs
def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
......@@ -814,17 +918,24 @@ class StructureModule(nn.Module):
requires_grad=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
def torsion_angles_to_frames(
self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, r, f # [*, N, 8] # [*, N]
self, r: Union[Rigid, Rigid3Array], f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
if type(r) == Rigid:
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
elif type(r) == Rigid3Array:
self._init_residue_constants(r.dtype, r.device)
else:
raise ValueError("Unknown rigid type")
return frames_and_literature_positions_to_atom14_pos(
r,
f,
......
......@@ -18,10 +18,12 @@ import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict
from typing import Any, Dict, Optional, Tuple, Union
from fastfold.common import protein
import fastfold.common.residue_constants as rc
from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from fastfold.utils.geometry.rotation_matrix import Rot3Array
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
batched_gather,
......@@ -36,7 +38,7 @@ def dgram_from_positions(
max_bin: float = 50.75,
no_bins: float = 39,
inf: float = 1e8,
):
) -> torch.Tensor:
dgram = torch.sum(
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
......@@ -46,8 +48,9 @@ def dgram_from_positions(
return dgram
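
build_template_pair_feat further down now calls dgram_from_positions instead of recomputing the distogram inline. A hedged usage sketch; the coordinates and bin settings are illustrative, and the arguments are passed positionally, matching the new call site below.

import torch
from fastfold.utils.feats import dgram_from_positions

pos = torch.randn(16, 3)                            # [N_res, 3] pseudo-beta coordinates
dgram = dgram_from_positions(pos, 3.25, 50.75, 39)  # min_bin, max_bin, no_bins
print(dgram.shape)                                  # torch.Size([16, 16, 39])
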
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
def pseudo_beta_fn(
aatype, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
......@@ -65,10 +68,10 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
return pseudo_beta, None
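
pseudo_beta_fn now always returns a 2-tuple, with None in place of the mask when no atom mask is passed; this is why the caller in model.py above unpacks x_prev, _. A hedged usage sketch with toy atom37 shapes, assuming the fastfold package is importable:

import torch
from fastfold.utils.feats import pseudo_beta_fn

aatype = torch.zeros(8, dtype=torch.long)      # [N_res]
all_atom_positions = torch.randn(8, 37, 3)     # [N_res, 37, 3]
pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, all_atom_positions, None)
print(pseudo_beta.shape, pseudo_beta_mask)     # torch.Size([8, 3]) None
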
def atom14_to_atom37(atom14, batch):
def atom14_to_atom37(atom14, batch: Dict[str, Any]):
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
......@@ -81,19 +84,15 @@ def atom14_to_atom37(atom14, batch):
return atom37_data
def build_template_angle_feat(template_feats):
def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor:
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats[
"template_alt_torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(
*torsion_angles_sin_cos.shape[:-2], 14
),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(
*alt_torsion_angles_sin_cos.shape[:-2], 14
),
......@@ -106,22 +105,20 @@ def build_template_angle_feat(template_feats):
def build_template_pair_feat(
batch,
min_bin, max_bin, no_bins,
use_unit_vector=False,
eps=1e-20, inf=1e8
batch: Dict[str, Any],
min_bin: float,
max_bin: float,
no_bins: int,
use_unit_vector: bool = False,
eps: float = 1e-20,
inf: float = 1e8,
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)
to_concat = [dgram, template_mask_2d[..., None]]
......@@ -137,9 +134,7 @@ def build_template_pair_feat(
)
)
to_concat.append(
aatype_one_hot[..., None, :].expand(
*aatype_one_hot.shape[:-2], -1, n_res, -1
)
aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
......@@ -152,19 +147,17 @@ def build_template_pair_feat(
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
)
template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None]
if(not use_unit_vector):
unit_vector = unit_vector * 0.
if not use_unit_vector:
unit_vector = unit_vector * 0.0
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
......@@ -175,7 +168,7 @@ def build_template_pair_feat(
return act
def build_extra_msa_feat(batch):
def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor:
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
......@@ -186,11 +179,11 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames(
r: Rigid,
r: Union[Rigid3Array, Rigid],
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
) -> Union[Rigid, Rigid3Array]:
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
......@@ -203,9 +196,7 @@ def torsion_angles_to_frames(
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
......@@ -216,16 +207,26 @@ def torsion_angles_to_frames(
# [ [1, 0,              0             ],
#   [0, alpha[..., 1], -alpha[..., 0] ],
#   [0, alpha[..., 0],  alpha[..., 1] ] ]
# This follows the original code rather than the supplement, which uses
# different indices.
if type(r) == Rigid3Array:
all_rots = alpha.new_zeros(default_r.shape + (3, 3))
elif type(r) == Rigid:
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
else:
raise TypeError(f"Wrong type of Rigid: {type(r)}")
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
if type(r) == Rigid3Array:
all_rots = Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
elif type(r) == Rigid:
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
else:
raise TypeError(f"Wrong type of Rigid: {type(r)}")
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
......@@ -236,6 +237,17 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
if type(all_frames) == Rigid3Array:
all_frames_to_bb = Rigid3Array.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
elif type(all_frames) == Rigid:
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
......@@ -252,13 +264,13 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
r: Union[Rigid3Array, Rigid],
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
default_frames: torch.Tensor,
group_idx: torch.Tensor,
atom_mask: torch.Tensor,
lit_positions: torch.Tensor,
) -> torch.Tensor:
# [*, N, 14, 4, 4]
default_4x4 = default_frames[aatype, ...]
......@@ -266,21 +278,30 @@ def frames_and_literature_positions_to_atom14_pos(
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
if type(r) == Rigid3Array:
group_mask = nn.functional.one_hot(
group_mask.long(),
num_classes=default_frames.shape[-3],
)
elif type(r) == Rigid:
group_mask = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
else:
raise TypeError(f"Wrong type of Rigid: {type(r)}")
# [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
lambda x: torch.sum(x, dim=-1)
)
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
# [*, N, 14, 1]
if type(r) == Rigid:
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
elif type(r) == Rigid3Array:
atom_mask = atom_mask[aatype, ...]
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
......
......@@ -39,6 +39,12 @@ class ParamType(Enum):
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearWeightMultimer = partial(
lambda w: w.unsqueeze(-1)
if len(w.shape) == 1
else w.reshape(w.shape[0], -1).transpose(-1, -2)
)
LinearBiasMultimer = partial(lambda w: w.reshape(-1))
Other = partial(lambda w: w)
def __init__(self, fn):
......@@ -121,29 +127,30 @@ def assign(translation_dict, orig_weights):
print(weights[0].shape)
raise
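
The new LinearWeightMultimer transform above converts a multimer JAX kernel into torch's nn.Linear layout: a 1-D array just gains a trailing axis, while anything else keeps its leading input dimension, flattens the remaining dimensions into the output dimension, and is transposed to [out_features, in_features]. A hedged illustration on a made-up per-head kernel:

import torch

w_jax = torch.randn(4, 8, 16)   # [c_in, no_heads, c_per_head], shapes made up
w_torch = w_jax.reshape(w_jax.shape[0], -1).transpose(-1, -2)
print(w_torch.shape)            # torch.Size([128, 4]) == [out_features, in_features]
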
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
def get_translation_dict(model, is_multimer: bool = False):
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearWeightMultimer = lambda l: (
Param(l, param_type=ParamType.LinearWeightMultimer)
)
LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
......@@ -239,7 +246,43 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
# New style IPA param
# "q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points),
# New style IPA param
# "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_v.weight,
),
},
"q_point_projection": PointProjectionParams(ipa.linear_q_points),
"k_point_projection": PointProjectionParams(ipa.linear_k_points),
"v_point_projection": PointProjectionParams(ipa.linear_v_points),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
......@@ -278,31 +321,26 @@ def import_jax_weights_(model, npz_path, version="model_1"):
msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
b.msa_att_row
),
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
"outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
"pair_transition": PairTransitionParams(b.core.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
def FoldIterationParams(sm):
d = {
"invariant_point_attention": IPAParamsMultimer(sm.ipa)
if is_multimer
else IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
......@@ -320,14 +358,17 @@ def import_jax_weights_(model, npz_path, version="model_1"):
},
}
if is_multimer:
d.pop("affine_update")
d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}
return d
############################
# translations dict overflow
############################
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
......@@ -335,6 +376,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
if not is_multimer:
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
......@@ -348,32 +390,30 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"pair_activiations": LinearParams(model.input_embedder.linear_relpos),
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"attention": AttentionParams(
model.template_embedder.template_pointwise_att.mha
),
},
"extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
model.template_embedder.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
model.template_embedder.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
......@@ -382,18 +422,106 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"initial_projection": LinearParams(model.structure_module.linear_in),
"pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
else:
temp_embedder = model.template_embedder
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"~_relative_encoding": {
"position_activations": LinearParams(
model.input_embedder.linear_relpos
),
},
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(temp_embedder.linear_t),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
"extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
"extra_msa_stack": ems_blocks_params,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(model.structure_module.linear_in),
"pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
......@@ -402,15 +530,22 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
"logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, is_multimer=("multimer" in version))
no_templ = [
"model_3",
"model_4",
......
......@@ -266,7 +266,7 @@ def inject_extraMsaBlock(model):
def inject_templatePairBlock(model):
with torch.no_grad():
target_module = model.template_pair_stack.blocks
target_module = model.template_embedder.template_pair_stack.blocks
fastfold_blocks = nn.ModuleList()
for block_id, ori_block in enumerate(target_module):
c_t = ori_block.c_t
......@@ -294,7 +294,7 @@ def inject_templatePairBlock(model):
fastfold_block.eval()
fastfold_blocks.append(fastfold_block)
model.template_pair_stack.blocks = fastfold_blocks
model.template_embedder.template_pair_stack.blocks = fastfold_blocks
def inject_fastnn(model):
......
......@@ -159,7 +159,7 @@ def main(args):
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if global_is_multimer:
print("multimer mode")
print("running in multimer mode...")
feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
else:
if (args.use_precomputed_alignments is None):
......
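
In multimer mode the inference script above currently skips alignment generation and loads a precomputed feature dict from a hard-coded pickle. A hedged sketch of that step; the path is the one in this commit, so it only runs where that file exists, and the keys noted in the comment are indicative rather than exhaustive.

import pickle

with open("/home/lcmql/data/features_pdb1o5d.pkl", "rb") as fh:
    feature_dict = pickle.load(fh)
print(sorted(feature_dict.keys()))  # expect multimer features such as asym_id, aatype, template_*
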