Commit a1c29028 authored by zhangqha

update uni-fold
import torch
import torch.nn as nn
from .common import (
residual,
)
from .featurization import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
build_template_pair_feat_v2,
atom14_to_atom37,
)
from .embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
)
from .evoformer import EvoformerStack, ExtraMSAStack
from .auxillary_heads import AuxiliaryHeads
from unifold.data import residue_constants
from .structure_module import StructureModule
from .template import (
TemplatePairStack,
TemplatePointwiseAttention,
TemplateProjection,
)
from unicore.utils import (
tensor_tree_map,
)
from .attentions import (
gen_msa_attn_mask,
gen_tri_attn_mask,
)
class AlphaFold(nn.Module):
def __init__(self, config):
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
self.input_embedder = InputEmbedder(
**config["input_embedder"],
use_chain_relative=config.is_multimer,
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
if config.template.enabled:
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
else:
self.template_pair_stack = None
self.enable_template_pointwise_attention = template_config[
"template_pointwise_attention"
].enabled
if self.enable_template_pointwise_attention:
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
else:
self.template_proj = TemplateProjection(
**template_config["template_pointwise_attention"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
**config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
)
self.config = config
self.dtype = torch.float
self.inf = self.globals.inf
if self.globals.alphafold_original_mode:
self.alphafold_original_mode()
def __make_input_float__(self):
self.input_embedder = self.input_embedder.float()
self.recycling_embedder = self.recycling_embedder.float()
def half(self):
super().half()
if (not getattr(self, "inference", False)):
self.__make_input_float__()
self.dtype = torch.half
return self
def bfloat16(self):
super().bfloat16()
if (not getattr(self, "inference", False)):
self.__make_input_float__()
self.dtype = torch.bfloat16
return self
def alphafold_original_mode(self):
def set_alphafold_original_mode(module):
if hasattr(module, "apply_alphafold_original_mode"):
module.apply_alphafold_original_mode()
if hasattr(module, "act"):
module.act = nn.ReLU()
self.apply(set_alphafold_original_mode)
def inference_mode(self):
def set_inference_mode(module):
setattr(module, "inference", True)
self.apply(set_inference_mode)
def __convert_input_dtype__(self, batch):
for key in batch:
# only convert features with mask
if batch[key].dtype != self.dtype and "mask" in key:
batch[key] = batch[key].type(self.dtype)
return batch
def embed_templates_pair_core(self, batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d):
if self.config.template.template_pair_embedder.v2_feature:
t = build_template_pair_feat_v2(
batch,
inf=self.config.template.inf,
eps=self.config.template.eps,
multichain_mask_2d=multichain_mask_2d,
**self.config.template.distogram,
)
num_template = t[0].shape[-4]
single_templates = [
self.template_pair_embedder([x[..., ti, :, :, :] for x in t], z)
for ti in range(num_template)
]
else:
t = build_template_pair_feat(
batch,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
)
single_templates = [
self.template_pair_embedder(x, z)
for x in torch.unbind(t, dim=templ_dim)
]
t = self.template_pair_stack(
single_templates,
pair_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
templ_dim=templ_dim,
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
return_mean=not self.enable_template_pointwise_attention,
)
return t
def embed_templates_pair(
self, batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim
):
if self.config.template.template_pair_embedder.v2_feature and "asym_id" in batch:
multichain_mask_2d = (
batch["asym_id"][..., :, None] == batch["asym_id"][..., None, :]
)
multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
else:
multichain_mask_2d = None
if self.training or self.enable_template_pointwise_attention:
t = self.embed_templates_pair_core(batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d)
if self.enable_template_pointwise_attention:
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"],
chunk_size=self.globals.chunk_size,
)
t_mask = torch.sum(batch["template_mask"], dim=-1, keepdims=True) > 0
t_mask = t_mask[..., None, None].type(t.dtype)
t *= t_mask
else:
t = self.template_proj(t, z)
else:
template_aatype_shape = batch["template_aatype"].shape
# template_aatype is either [n_template, n_res] or [1, n_template, n_res]
batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
n_templ = batch["template_aatype"].shape[batch_templ_dim]
if n_templ <= 0:
t = None
else:
template_batch = { k: v for k, v in batch.items() if k.startswith("template_") }
def embed_one_template(i):
def slice_template_tensor(t):
s = [slice(None) for _ in t.shape]
s[batch_templ_dim] = slice(i, i + 1)
return t[s]
template_feats = tensor_tree_map(
slice_template_tensor,
template_batch,
)
t = self.embed_templates_pair_core(template_feats, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d)
return t
t = embed_one_template(0)
# iterate templates one by one
for i in range(1, n_templ):
t += embed_one_template(i)
t /= n_templ
t = self.template_proj(t, z)
return t
def embed_templates_angle(self, batch):
template_angle_feat, template_angle_mask = build_template_angle_feat(
batch, v2_feature=self.config.template.template_pair_embedder.v2_feature
)
t = self.template_angle_embedder(template_angle_feat)
return t, template_angle_mask
def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
batch_dims = feats["target_feat"].shape[:-2]
n = feats["target_feat"].shape[-2]
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
m, z = self.input_embedder(
feats["target_feat"],
feats["msa_feat"],
)
if m_1_prev is None:
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.d_msa),
requires_grad=False,
)
if z_prev is None:
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.d_pair),
requires_grad=False,
)
if x_prev is None:
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
z += self.recycling_embedder.recyle_pos(x_prev)
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
)
m[..., 0, :, :] += m_1_prev_emb
z += z_prev_emb
z += self.input_embedder.relpos_emb(
feats["residue_index"].long(),
feats.get("sym_id", None),
feats.get("asym_id", None),
feats.get("entity_id", None),
feats.get("num_sym", None),
)
m = m.type(self.dtype)
z = z.type(self.dtype)
tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(pair_mask, self.inf)
if self.config.template.enabled:
template_mask = feats["template_mask"]
if torch.any(template_mask):
z = residual(
z,
self.embed_templates_pair(
feats,
z,
pair_mask,
tri_start_attn_mask,
tri_end_attn_mask,
templ_dim=-4,
),
self.training,
)
if self.config.extra_msa.enabled:
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
extra_msa_row_mask = gen_msa_attn_mask(
feats["extra_msa_mask"],
inf=self.inf,
gen_col_mask=False,
)
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
pair_mask=pair_mask,
msa_row_attn_mask=extra_msa_row_mask,
msa_col_attn_mask=None,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
)
if self.config.template.embed_angles:
template_1d_feat, template_1d_mask = self.embed_templates_angle(feats)
m = torch.cat([m, template_1d_feat], dim=-3)
msa_mask = torch.cat([feats["msa_mask"], template_1d_mask], dim=-2)
msa_row_mask, msa_col_mask = gen_msa_attn_mask(
msa_mask,
inf=self.inf,
)
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_mask,
msa_col_attn_mask=msa_col_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
)
return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb
def iteration_evoformer_structure_module(
self, batch, m_1_prev, z_prev, x_prev, cycle_no, num_recycling, num_ensembles=1
):
z, s = 0, 0
n_seq = batch["msa_feat"].shape[-3]
assert num_ensembles >= 1
for ensemble_no in range(num_ensembles):
idx = cycle_no * num_ensembles + ensemble_no
fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
feats = tensor_tree_map(fetch_cur_batch, batch)
m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
feats, m_1_prev, z_prev, x_prev
)
z += z0
s += s0
del z0, s0
if num_ensembles > 1:
z /= float(num_ensembles)
s /= float(num_ensembles)
outputs = {}
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# deltas for the recycling norm loss (computed only on the last recycle during training)
if (not getattr(self, "inference", False)) and num_recycling == (cycle_no + 1):
delta_msa = m
delta_msa[..., 0, :, :] = delta_msa[..., 0, :, :] - m_1_prev_emb.detach()
delta_pair = z - z_prev_emb.detach()
outputs["delta_msa"] = delta_msa
outputs["delta_pair"] = delta_pair
outputs["msa_norm_mask"] = msa_mask
outputs["sm"] = self.structure_module(
s,
z,
feats["aatype"],
mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["pred_frame_tensor"] = outputs["sm"]["frames"][-1]
# use float32 for numerical stability
if (not getattr(self, "inference", False)):
m_1_prev = m[..., 0, :, :].float()
z_prev = z.float()
x_prev = outputs["final_atom_positions"].float()
else:
m_1_prev = m[..., 0, :, :]
z_prev = z
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch):
m_1_prev = batch.get("m_1_prev", None)
z_prev = batch.get("z_prev", None)
x_prev = batch.get("x_prev", None)
is_grad_enabled = torch.is_grad_enabled()
num_iters = int(batch["num_recycling_iters"]) + 1
num_ensembles = int(batch["msa_mask"].shape[0]) // num_iters
if self.training:
# don't use ensemble during training
assert num_ensembles == 1
# convert dtypes in batch
batch = self.__convert_input_dtype__(batch)
for cycle_no in range(num_iters):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
(
outputs,
m_1_prev,
z_prev,
x_prev,
) = self.iteration_evoformer_structure_module(
batch,
m_1_prev,
z_prev,
x_prev,
cycle_no=cycle_no,
num_recycling=num_iters,
num_ensembles=num_ensembles,
)
if not is_final_iter:
del outputs
if "asym_id" in batch:
outputs["asym_id"] = batch["asym_id"][0, ...]
outputs.update(self.aux_heads(outputs))
return outputs
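# Editor's note on the recycling loop above, grounded in the code: the input batch stacks
# num_iters * num_ensembles feature sets along the leading dimension; each recycle averages
# the evoformer outputs over its ensemble members, and torch.set_grad_enabled keeps
# gradients only for the final recycle, matching AlphaFold's "backprop through the last
# recycle only" recipe. During training num_ensembles is asserted to be 1.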
from functools import partialmethod
from typing import Optional, List
import torch
import torch.nn as nn
from .common import Linear, chunk_layer
from unicore.utils import (
permute_final_dims,
)
from unicore.modules import (
softmax_dropout,
LayerNorm,
)
def gen_attn_mask(mask, neg_inf):
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
return attn_mask
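# gen_attn_mask converts a {0, 1} validity mask into an additive attention bias: entries
# where mask == 0 become a large negative number (neg_inf) and the rest stay 0, so adding
# the result to the attention logits before softmax drives masked positions to ~zero
# probability. A tiny illustration with made-up values:
#   mask = [1, 1, 0]  ->  attn_mask = [0, 0, neg_inf]
#   softmax(logits + attn_mask) then all but ignores the last position.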
class Attention(nn.Module):
def __init__(
self,
q_dim: int,
k_dim: int,
v_dim: int,
head_dim: int,
num_heads: int,
gating: bool = True,
):
super(Attention, self).__init__()
self.num_heads = num_heads
total_dim = head_dim * self.num_heads
self.gating = gating
self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot")
self.linear_k = Linear(k_dim, total_dim, bias=False, init="glorot")
self.linear_v = Linear(v_dim, total_dim, bias=False, init="glorot")
self.linear_o = Linear(total_dim, q_dim, init="final")
self.linear_g = None
if self.gating:
self.linear_g = Linear(q_dim, total_dim, init="gating")
# precompute the 1/sqrt(head_dim)
self.norm = head_dim**-0.5
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
g = None
if self.linear_g is not None:
# gating, use raw query input
g = self.linear_g(q)
q = self.linear_q(q)
q *= self.norm
k = self.linear_k(k)
v = self.linear_v(v)
q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)
attn = torch.matmul(q, k.transpose(-1, -2))
del q, k
attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
o = torch.matmul(attn, v)
del attn, v
o = o.transpose(-2, -3).contiguous()
o = o.view(*o.shape[:-2], -1)
if g is not None:
o = torch.sigmoid(g) * o
# output projection; the bias of linear_o is applied later (see get_output_bias) so it can be fused into the residual update
o = nn.functional.linear(o, self.linear_o.weight)
return o
def get_output_bias(self):
return self.linear_o.bias
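# Design note (editor's reading of the code): the output projection above is applied with
# nn.functional.linear using only linear_o.weight; the corresponding bias is exposed via
# get_output_bias() so callers such as bias_dropout_residual can fuse "x + bias", dropout
# and the residual add into one jit-scripted elementwise kernel.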
class GlobalAttention(nn.Module):
def __init__(self, input_dim, head_dim, num_heads, inf, eps):
super(GlobalAttention, self).__init__()
self.num_heads = num_heads
self.inf = inf
self.eps = eps
self.linear_q = Linear(
input_dim, head_dim * num_heads, bias=False, init="glorot"
)
self.linear_k = Linear(input_dim, head_dim, bias=False, init="glorot")
self.linear_v = Linear(input_dim, head_dim, bias=False, init="glorot")
self.linear_g = Linear(input_dim, head_dim * num_heads, init="gating")
self.linear_o = Linear(head_dim * num_heads, input_dim, init="final")
self.sigmoid = nn.Sigmoid()
# precompute the 1/sqrt(head_dim)
self.norm = head_dim**-0.5
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# gating
g = self.sigmoid(self.linear_g(x))
k = self.linear_k(x)
v = self.linear_v(x)
q = torch.sum(x * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1, keepdims=True) + self.eps
)
q = self.linear_q(q)
q *= self.norm
q = q.view(q.shape[:-1] + (self.num_heads, -1))
attn = torch.matmul(q, k.transpose(-1, -2))
del q, k
attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
o = torch.matmul(
attn,
v,
)
del attn, v
g = g.view(g.shape[:-1] + (self.num_heads, -1))
o = o.unsqueeze(-3) * g
del g
# merge heads
o = o.reshape(o.shape[:-2] + (-1,))
return self.linear_o(o)
def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
if gen_col_mask:
col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, None, :]
return row_mask, col_mask
else:
return row_mask
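# Shape sketch (editor's note): for an MSA mask of shape [*, n_seq, n_res], the row mask
# broadcasts as [*, n_seq, 1, 1, n_res] over heads and query positions, while the column
# mask is built from the transposed mask with shape [*, n_res, 1, 1, n_seq]; they serve
# row-wise (over residues) and column-wise (over sequences) MSA attention respectively.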
class MSAAttention(nn.Module):
def __init__(
self,
d_in,
d_hid,
num_heads,
pair_bias=False,
d_pair=None,
):
super(MSAAttention, self).__init__()
self.pair_bias = pair_bias
self.layer_norm_m = LayerNorm(d_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = LayerNorm(d_pair)
self.linear_z = Linear(d_pair, num_heads, bias=False, init="normal")
self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)
@torch.jit.ignore
def _chunk(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
chunk_size: int = None,
) -> torch.Tensor:
return chunk_layer(
self._attn_forward,
{"m": m, "mask": mask, "bias": bias},
chunk_size=chunk_size,
num_batch_dims=len(m.shape[:-2]),
)
@torch.jit.ignore
def _attn_chunk_forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = 2560,
) -> torch.Tensor:
m = self.layer_norm_m(m)
num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size
outputs = []
for i in range(num_chunk):
chunk_start = i * chunk_size
chunk_end = min(m.shape[-3], chunk_start + chunk_size)
cur_m = m[..., chunk_start:chunk_end, :, :]
cur_mask = (
mask[..., chunk_start:chunk_end, :, :, :] if mask is not None else None
)
outputs.append(
self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)
)
return torch.concat(outputs, dim=-3)
def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
m = self.layer_norm_m(m)
return self.mha(q=m, k=m, v=m, mask=mask, bias=bias)
def forward(
self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
bias = None
if self.pair_bias:
z = self.layer_norm_z(z)
bias = (
permute_final_dims(self.linear_z(z), (2, 0, 1))
.unsqueeze(-4)
.contiguous()
)
if chunk_size is not None:
m = self._chunk(m, attn_mask, bias, chunk_size)
else:
attn_chunk_size = 2560
if m.shape[-3] <= attn_chunk_size:
m = self._attn_forward(m, attn_mask, bias)
else:
# reduce the peak memory cost in extra_msa_stack
return self._attn_chunk_forward(
m, attn_mask, bias, chunk_size=attn_chunk_size
)
return m
def get_output_bias(self):
return self.mha.get_output_bias()
class MSARowAttentionWithPairBias(MSAAttention):
def __init__(self, d_msa, d_pair, d_hid, num_heads):
super(MSARowAttentionWithPairBias, self).__init__(
d_msa,
d_hid,
num_heads,
pair_bias=True,
d_pair=d_pair,
)
class MSAColumnAttention(MSAAttention):
def __init__(self, d_msa, d_hid, num_heads):
super(MSAColumnAttention, self).__init__(
d_in=d_msa,
d_hid=d_hid,
num_heads=num_heads,
pair_bias=False,
d_pair=None,
)
def forward(
self,
m: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
m = m.transpose(-2, -3)
m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size)
m = m.transpose(-2, -3)
return m
class MSAColumnGlobalAttention(nn.Module):
def __init__(
self,
d_in,
d_hid,
num_heads,
inf=1e9,
eps=1e-10,
):
super(MSAColumnGlobalAttention, self).__init__()
self.layer_norm_m = LayerNorm(d_in)
self.global_attention = GlobalAttention(
d_in,
d_hid,
num_heads,
inf=inf,
eps=eps,
)
@torch.jit.ignore
def _chunk(
self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._attn_forward,
{"m": m, "mask": mask},
chunk_size=chunk_size,
num_batch_dims=len(m.shape[:-2]),
)
def _attn_forward(self, m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask=mask)
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
m = self._attn_forward(m, mask=mask)
m = m.transpose(-2, -3)
return m
def gen_tri_attn_mask(mask, inf):
start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, None, :]
return start_mask, end_mask
class TriangleAttention(nn.Module):
def __init__(
self,
d_in,
d_hid,
num_heads,
starting,
):
super(TriangleAttention, self).__init__()
self.starting = starting
self.layer_norm = LayerNorm(d_in)
self.linear = Linear(d_in, num_heads, bias=False, init="normal")
self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)
@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
chunk_size: int = None,
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q": x, "k": x, "v": x, "mask": mask, "bias": bias},
chunk_size=chunk_size,
num_batch_dims=len(x.shape[:-2]),
)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
if not self.starting:
x = x.transpose(-2, -3)
x = self.layer_norm(x)
triangle_bias = (
permute_final_dims(self.linear(x), (2, 0, 1)).unsqueeze(-4).contiguous()
)
if chunk_size is not None:
x = self._chunk(x, attn_mask, triangle_bias, chunk_size)
else:
x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)
if not self.starting:
x = x.transpose(-2, -3)
return x
def get_output_bias(self):
return self.mha.get_output_bias()
class TriangleAttentionStarting(TriangleAttention):
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
class TriangleAttentionEnding(TriangleAttention):
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
import torch.nn as nn
from typing import Dict
from unicore.modules import LayerNorm
from .common import Linear
from .confidence import predicted_lddt, predicted_tm_score, predicted_aligned_error
class AuxiliaryHeads(nn.Module):
def __init__(self, config):
super(AuxiliaryHeads, self).__init__()
self.plddt = PredictedLDDTHead(
**config["plddt"],
)
self.distogram = DistogramHead(
**config["distogram"],
)
self.masked_msa = MaskedMSAHead(
**config["masked_msa"],
)
if config.experimentally_resolved.enabled:
self.experimentally_resolved = ExperimentallyResolvedHead(
**config["experimentally_resolved"],
)
if config.pae.enabled:
self.pae = PredictedAlignedErrorHead(
**config.pae,
)
self.config = config
def forward(self, outputs):
aux_out = {}
plddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["plddt_logits"] = plddt_logits
aux_out["plddt"] = predicted_lddt(plddt_logits.detach())
distogram_logits = self.distogram(outputs["pair"])
aux_out["distogram_logits"] = distogram_logits
masked_msa_logits = self.masked_msa(outputs["msa"])
aux_out["masked_msa_logits"] = masked_msa_logits
if self.config.experimentally_resolved.enabled:
exp_res_logits = self.experimentally_resolved(outputs["single"])
aux_out["experimentally_resolved_logits"] = exp_res_logits
if self.config.pae.enabled:
pae_logits = self.pae(outputs["pair"])
aux_out["pae_logits"] = pae_logits
pae_logits = pae_logits.detach()
aux_out.update(
predicted_aligned_error(
pae_logits,
**self.config.pae,
)
)
aux_out["ptm"] = predicted_tm_score(
pae_logits, interface=False, **self.config.pae
)
iptm_weight = self.config.pae.get("iptm_weight", 0.0)
if iptm_weight > 0.0:
aux_out["iptm"] = predicted_tm_score(
pae_logits,
interface=True,
asym_id=outputs["asym_id"],
**self.config.pae,
)
aux_out["iptm+ptm"] = (
iptm_weight * aux_out["iptm"] + (1.0 - iptm_weight) * aux_out["ptm"]
)
return aux_out
class PredictedLDDTHead(nn.Module):
def __init__(self, num_bins, d_in, d_hid):
super(PredictedLDDTHead, self).__init__()
self.num_bins = num_bins
self.d_in = d_in
self.d_hid = d_hid
self.layer_norm = LayerNorm(self.d_in)
self.linear_1 = Linear(self.d_in, self.d_hid, init="relu")
self.linear_2 = Linear(self.d_hid, self.d_hid, init="relu")
self.act = nn.GELU()
self.linear_3 = Linear(self.d_hid, self.num_bins, init="final")
def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.act(s)
s = self.linear_2(s)
s = self.act(s)
s = self.linear_3(s)
return s
class EnhancedHeadBase(nn.Module):
def __init__(self, d_in, d_out, disable_enhance_head):
super(EnhancedHeadBase, self).__init__()
if disable_enhance_head:
self.layer_norm = None
self.linear_in = None
else:
self.layer_norm = LayerNorm(d_in)
self.linear_in = Linear(d_in, d_in, init="relu")
self.act = nn.GELU()
self.linear = Linear(d_in, d_out, init="final")
def apply_alphafold_original_mode(self):
self.layer_norm = None
self.linear_in = None
def forward(self, x):
if self.layer_norm is not None:
x = self.layer_norm(x)
x = self.act(self.linear_in(x))
logits = self.linear(x)
return logits
class DistogramHead(EnhancedHeadBase):
def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
super(DistogramHead, self).__init__(
d_in=d_pair,
d_out=num_bins,
disable_enhance_head=disable_enhance_head,
)
def forward(self, x):
logits = super().forward(x)
logits = logits + logits.transpose(-2, -3)
return logits
class PredictedAlignedErrorHead(EnhancedHeadBase):
def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
super(PredictedAlignedErrorHead, self).__init__(
d_in=d_pair,
d_out=num_bins,
disable_enhance_head=disable_enhance_head,
)
class MaskedMSAHead(EnhancedHeadBase):
def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs):
super(MaskedMSAHead, self).__init__(
d_in=d_msa,
d_out=d_out,
disable_enhance_head=disable_enhance_head,
)
class ExperimentallyResolvedHead(EnhancedHeadBase):
def __init__(self, d_single, d_out, disable_enhance_head, **kwargs):
super(ExperimentallyResolvedHead, self).__init__(
d_in=d_single,
d_out=d_out,
disable_enhance_head=disable_enhance_head,
)
from functools import partial
from typing import Optional, Any, Callable, List, Dict, Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from unicore.modules import LayerNorm
from unicore.utils import tensor_tree_map
class Linear(nn.Linear):
def __init__(
self,
d_in: int,
d_out: int,
bias: bool = True,
init: str = "default",
):
super(Linear, self).__init__(d_in, d_out, bias=bias)
self.use_bias = bias
if self.use_bias:
with torch.no_grad():
self.bias.fill_(0)
if init == "default":
self._trunc_normal_init(1.0)
elif init == "relu":
self._trunc_normal_init(2.0)
elif init == "glorot":
self._glorot_uniform_init()
elif init == "gating":
self._zero_init(self.use_bias)
elif init == "normal":
self._normal_init()
elif init == "final":
self._zero_init(False)
else:
raise ValueError("Invalid init method.")
def _trunc_normal_init(self, scale=1.0):
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
_, fan_in = self.weight.shape
scale = scale / max(1, fan_in)
std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR
nn.init.trunc_normal_(self.weight, mean=0.0, std=std)
def _glorot_uniform_init(self):
nn.init.xavier_uniform_(self.weight, gain=1)
def _zero_init(self, use_bias=True):
with torch.no_grad():
self.weight.fill_(0.0)
if use_bias:
with torch.no_grad():
self.bias.fill_(1.0)
def _normal_init(self):
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="linear")
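# Initialization summary (editor's note, derived from the branches above): "default" and
# "relu" use a truncated normal whose variance is scaled by 1/fan_in (gains 1.0 and 2.0),
# "glorot" is Xavier uniform, "gating" zeroes the weight and sets the bias to 1 so gates
# start open, "normal" is Kaiming normal with a linear nonlinearity, and "final" zeroes
# the weight so the layer initially contributes nothing to a residual branch.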
class Transition(nn.Module):
def __init__(self, d_in, n):
super(Transition, self).__init__()
self.d_in = d_in
self.n = n
self.layer_norm = LayerNorm(self.d_in)
self.linear_1 = Linear(self.d_in, self.n * self.d_in, init="relu")
self.act = nn.GELU()
self.linear_2 = Linear(self.n * self.d_in, d_in, init="final")
def _transition(self, x):
x = self.layer_norm(x)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"x": x},
chunk_size=chunk_size,
num_batch_dims=len(x.shape[:-2]),
)
def forward(
self,
x: torch.Tensor,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
if chunk_size is not None:
x = self._chunk(x, chunk_size)
else:
x = self._transition(x=x)
return x
class OuterProductMean(nn.Module):
def __init__(self, d_msa, d_pair, d_hid, eps=1e-3):
super(OuterProductMean, self).__init__()
self.d_msa = d_msa
self.d_pair = d_pair
self.d_hid = d_hid
self.eps = eps
self.layer_norm = LayerNorm(d_msa)
self.linear_1 = Linear(d_msa, d_hid)
self.linear_2 = Linear(d_msa, d_hid)
self.linear_out = Linear(d_hid**2, d_pair, init="relu")
self.act = nn.GELU()
self.linear_z = Linear(self.d_pair, self.d_pair, init="final")
self.layer_norm_out = LayerNorm(self.d_pair)
def _opm(self, a, b):
outer = torch.einsum("...bac,...dae->...bdce", a, b)
outer = outer.reshape(outer.shape[:-2] + (-1,))
outer = self.linear_out(outer)
return outer
@torch.jit.ignore
def _chunk(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
a = a.reshape((-1,) + a.shape[-3:])
b = b.reshape((-1,) + b.shape[-3:])
out = []
# TODO: optimize this
for a_prime, b_prime in zip(a, b):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=chunk_size,
num_batch_dims=1,
)
out.append(outer)
if len(out) == 1:
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def apply_alphafold_original_mode(self):
self.linear_z = None
self.layer_norm_out = None
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
m = self.layer_norm(m)
mask = mask.unsqueeze(-1)
if self.layer_norm_out is not None:
# for numerical stability
mask = mask * (mask.size(-2) ** -0.5)
a = self.linear_1(m)
b = self.linear_2(m)
if self.training:
a = a * mask
b = b * mask
else:
a *= mask
b *= mask
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if chunk_size is not None:
z = self._chunk(a, b, chunk_size)
else:
z = self._opm(a, b)
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
z /= self.eps + norm
if self.layer_norm_out is not None:
z = self.act(z)
z = self.layer_norm_out(z)
z = self.linear_z(z)
return z
def residual(residual, x, training):
if training:
return x + residual
else:
residual += x
return residual
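# residual() keeps the out-of-place "x + residual" during training, where autograd needs
# the original tensors, and switches to an in-place "residual += x" at inference to avoid
# allocating another activation-sized buffer; the fused bias/dropout helpers below follow
# the same pattern.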
@torch.jit.script
def fused_bias_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
dropmask: torch.Tensor,
prob: float,
) -> torch.Tensor:
return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual
@torch.jit.script
def fused_bias_dropout_add_inference(
x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
residual += bias + x
return residual
def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, training):
bias = module.get_output_bias()
if training:
shape = list(x.shape)
shape[dropout_shared_dim] = 1
with torch.no_grad():
mask = x.new_ones(shape)
return fused_bias_dropout_add(x, bias, residual, mask, prob)
else:
return fused_bias_dropout_add_inference(x, bias, residual)
@torch.jit.script
def fused_bias_gated_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
g: torch.Tensor,
g_bias: torch.Tensor,
residual: torch.Tensor,
dropout_mask: torch.Tensor,
prob: float,
) -> torch.Tensor:
return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout(
dropout_mask,
p=prob,
training=True,
) + residual
def tri_mul_residual(
module,
residual,
outputs,
dropout_shared_dim,
prob,
training,
block_size,
):
if training:
x, g = outputs
bias, g_bias = module.get_output_bias()
shape = list(x.shape)
shape[dropout_shared_dim] = 1
with torch.no_grad():
mask = x.new_ones(shape)
return fused_bias_gated_dropout_add(
x,
bias,
g,
g_bias,
residual,
mask,
prob,
)
elif block_size is None:
x, g = outputs
bias, g_bias = module.get_output_bias()
residual += (torch.sigmoid(g + g_bias) * (x + bias))
return residual
else:
# in the block-wise inference path the module returns a single finalized output (gating already applied), so only the residual add remains
residual += outputs
return residual
class SimpleModuleList(nn.ModuleList):
def __repr__(self):
return str(len(self)) + " X ...\n" + self[0].__repr__()
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
num_batch_dims: int,
) -> Any:
# TODO: support inplace add to output
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
def _dict_get_shapes(input):
shapes = []
if type(input) is torch.Tensor:
shapes.append(input.shape)
elif type(input) is dict:
for v in input.values():
shapes.extend(_dict_get_shapes(v))
elif isinstance(input, Iterable):
for v in input:
shapes.extend(_dict_get_shapes(v))
else:
raise ValueError("Not supported")
return shapes
inputs = {k: v for k, v in inputs.items() if v is not None}
initial_dims = [shape[:num_batch_dims] for shape in _dict_get_shapes(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size
def _flat_inputs(t):
t = t.view(-1, *t.shape[num_batch_dims:])
assert (
t.shape[0] == flat_batch_dim or t.shape[0] == 1
), "batch dimension must be 1 or equal to the flat batch dimension"
return t
flat_inputs = tensor_tree_map(_flat_inputs, inputs)
out = None
for i in range(num_chunks):
chunk_start = i * chunk_size
chunk_end = min((i + 1) * chunk_size, flat_batch_dim)
def select_chunk(t):
if t.shape[0] == 1:
return t[0:1]
else:
return t[chunk_start:chunk_end]
chunkes = tensor_tree_map(select_chunk, flat_inputs)
output_chunk = layer(**chunkes)
if out is None:
out = tensor_tree_map(
lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk
)
out_type = type(output_chunk)
if out_type is tuple:
for x, y in zip(out, output_chunk):
x[chunk_start:chunk_end] = y
elif out_type is torch.Tensor:
out[chunk_start:chunk_end] = output_chunk
else:
raise ValueError("Not supported")
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
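# Rough usage sketch for chunk_layer (editor's note, hypothetical shapes): given a layer f
# whose inputs share leading batch dims (B, N), the inputs are flattened to (B*N, ...),
# processed in slices of at most chunk_size rows, written into a preallocated output
# buffer, and reshaped back to (B, N, ...). Inputs whose flattened batch dimension is 1
# are broadcast to every chunk instead of being sliced.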
import torch
from typing import Dict, Optional, Tuple
def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
"""Computes per-residue pLDDT from logits.
Args:
plddt_logits: [num_res, num_bins] logits output from the PredictedLDDTHead.
Returns:
plddt: [num_res] per-residue pLDDT.
"""
num_bins = plddt_logits.shape[-1]
bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=plddt_logits.device
)
plddt = torch.sum(
bin_probs * bounds.view(*((1,) * len(bin_probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return plddt
def compute_bin_values(breaks: torch.Tensor):
"""Gets the bin centers from the bin edges.
Args:
breaks: [num_bins - 1] the error bin edges.
Returns:
bin_centers: [num_bins] the error bin centers.
"""
step = breaks[1] - breaks[0]
bin_values = breaks + step / 2
bin_values = torch.cat(
[bin_values, (bin_values[-1] + step).unsqueeze(-1)], dim=0
)
return bin_values
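# Worked example (editor's, made-up numbers): for breaks = [0.0, 0.5, 1.0] the step is 0.5,
# so the centers are [0.25, 0.75, 1.25] plus an extra 1.75 appended for the open-ended last
# bin, i.e. num_bins - 1 edges yield num_bins bin centers.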
def compute_predicted_aligned_error(
bin_edges: torch.Tensor,
bin_probs: torch.Tensor,
) -> torch.Tensor:
"""Calculates expected aligned distance errors for every pair of residues.
Args:
bin_edges: [num_bins - 1] the error bin edges.
bin_probs: [num_res, num_res, num_bins] the predicted
probs for each error bin, for each pair of residues.
Returns:
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
"""
bin_values = compute_bin_values(bin_edges)
return torch.sum(bin_probs * bin_values, dim=-1)
def predicted_aligned_error(
pae_logits: torch.Tensor,
max_bin: int = 31,
num_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
pae_logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin: the largest finite error bin edge (in Angstroms).
num_bins: the number of error bins.
Returns:
aligned_error_probs_per_bin: [num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
"""
bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1)
bin_edges = torch.linspace(0, max_bin, steps=(num_bins - 1), device=pae_logits.device)
predicted_aligned_error = compute_predicted_aligned_error(
bin_edges=bin_edges,
bin_probs=bin_probs,
)
return {
"aligned_error_probs_per_bin": bin_probs,
"predicted_aligned_error": predicted_aligned_error,
}
def predicted_tm_score(
pae_logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
num_bins: int = 64,
eps: float = 1e-8,
asym_id: Optional[torch.Tensor] = None,
interface: bool = False,
**kwargs,
) -> torch.Tensor:
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
pae_logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin, num_bins: parameters defining the error bin edges.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation, i.e. when interface=True.
interface: If True, interface predicted TM score is computed.
Returns:
ptm_score: The predicted TM (pTM) score, or the predicted interface TM (ipTM) score when interface=True.
"""
pae_logits = pae_logits.float()
if residue_weights is None:
residue_weights = pae_logits.new_ones(pae_logits.shape[:-2])
breaks = torch.linspace(0, max_bin, steps=(num_bins - 1), device=pae_logits.device)
def tm_kernal(nres):
clipped_n = max(nres, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3.0) - 1.8
return lambda x: 1.0 / (1.0 + (x / d0) ** 2)
def rmsd_kernal(eps): # leave for compute pRMS
return lambda x: 1. / (x + eps)
bin_centers = compute_bin_values(breaks)
probs = torch.nn.functional.softmax(pae_logits, dim=-1)
tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers)
# tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
# rmsd_per_bin = rmsd_kernal()(bin_centers)
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape)
if interface:
assert asym_id is not None, "must provide asym_id for iptm calculation."
pair_mask *= asym_id[..., :, None] != asym_id[..., None, :]
predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * (
residue_weights[None, :] * residue_weights[:, None]
)
normed_residue_mask = pair_residue_weights / (
eps + pair_residue_weights.sum(dim=-1, keepdim=True)
)
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
ret = per_alignment.gather(
dim=-1, index=weighted.max(dim=-1, keepdim=True).indices
).squeeze(dim=-1)
return ret
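# Editor's note on the kernel above: d0 follows the standard TM-score normalization
# d0 = 1.24 * (N - 15)**(1/3) - 1.8 with N clipped to at least 19, and each error bin
# contributes 1 / (1 + (e / d0)**2). The score keeps the best "alignment": the row of the
# weighted pairwise term with the largest residue-weighted sum. For ipTM the pair mask
# restricts the expectation to inter-chain residue pairs.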
import torch
import torch.nn as nn
from typing import Optional, Tuple
from unicore.utils import one_hot
from .common import Linear, residual
from .common import SimpleModuleList
from unicore.modules import LayerNorm
class InputEmbedder(nn.Module):
def __init__(
self,
tf_dim: int,
msa_dim: int,
d_pair: int,
d_msa: int,
relpos_k: int,
use_chain_relative: bool = False,
max_relative_chain: Optional[int] = None,
**kwargs,
):
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.d_pair = d_pair
self.d_msa = d_msa
self.linear_tf_z_i = Linear(tf_dim, d_pair)
self.linear_tf_z_j = Linear(tf_dim, d_pair)
self.linear_tf_m = Linear(tf_dim, d_msa)
self.linear_msa_m = Linear(msa_dim, d_msa)
# relative position encoding (RPE)
self.relpos_k = relpos_k
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if not self.use_chain_relative:
self.num_bins = 2 * self.relpos_k + 1
else:
self.num_bins = 2 * self.relpos_k + 2
self.num_bins += 1 # entity id
self.num_bins += 2 * max_relative_chain + 2
self.linear_relpos = Linear(self.num_bins, d_pair)
def _relpos_indices(
self,
res_id: torch.Tensor,
sym_id: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
entity_id: Optional[torch.Tensor] = None,
):
max_rel_res = self.relpos_k
rp = res_id[..., None] - res_id[..., None, :]
rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res
if not self.use_chain_relative:
return rp
else:
asym_id_same = asym_id[..., :, None] == asym_id[..., None, :]
rp[~asym_id_same] = 2 * max_rel_res + 1
entity_id_same = entity_id[..., :, None] == entity_id[..., None, :]
rp_entity_id = entity_id_same.type(rp.dtype)[..., None]
rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain
)
clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1
return rp, rp_entity_id, clipped_rel_chain
def relpos_emb(
self,
res_id: torch.Tensor,
sym_id: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
entity_id: Optional[torch.Tensor] = None,
num_sym: Optional[torch.Tensor] = None,
):
dtype = self.linear_relpos.weight.dtype
if not self.use_chain_relative:
rp = self._relpos_indices(res_id=res_id)
return self.linear_relpos(
one_hot(rp, num_classes=self.num_bins, dtype=dtype)
)
else:
rp, rp_entity_id, rp_rel_chain = self._relpos_indices(
res_id=res_id, sym_id=sym_id, asym_id=asym_id, entity_id=entity_id
)
rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype)
rp_entity_id = rp_entity_id.type(dtype)
rp_rel_chain = one_hot(
rp_rel_chain, num_classes=(2 * self.max_relative_chain + 2), dtype=dtype
)
return self.linear_relpos(
torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)
)
def forward(
self,
tf: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# [*, N_res, d_pair]
if self.tf_dim == 21:
# the multimer model uses a 21-dim target feature, so drop the leading channel of tf
tf = tf[..., 1:]
# convert type if necessary
tf = tf.type(self.linear_tf_z_i.weight.dtype)
msa = msa.type(self.linear_tf_z_i.weight.dtype)
n_clust = msa.shape[-3]
msa_emb = self.linear_msa_m(msa)
# add the target_feat (aatype) embedding into the MSA representation
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1))) # expand -3 dim
)
msa_emb += tf_m
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
return msa_emb, pair_emb
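# Pair initialization (editor's note): tf_emb_i and tf_emb_j are per-residue projections of
# the target features, and their broadcast outer sum over the two residue dimensions yields
# the initial [*, N_res, N_res, d_pair] pair representation; the relative-position term is
# added separately in iteration_evoformer via relpos_emb.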
class RecyclingEmbedder(nn.Module):
def __init__(
self,
d_msa: int,
d_pair: int,
min_bin: float,
max_bin: float,
num_bins: int,
inf: float = 1e8,
**kwargs,
):
super(RecyclingEmbedder, self).__init__()
self.d_msa = d_msa
self.d_pair = d_pair
self.min_bin = min_bin
self.max_bin = max_bin
self.num_bins = num_bins
self.inf = inf
self.squared_bins = None
self.linear = Linear(self.num_bins, self.d_pair)
self.layer_norm_m = LayerNorm(self.d_msa)
self.layer_norm_z = LayerNorm(self.d_pair)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
m_update = self.layer_norm_m(m)
z_update = self.layer_norm_z(z)
return m_update, z_update
def recyle_pos(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.squared_bins is None:
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.num_bins,
dtype=torch.float if self.training else x.dtype,
device=x.device,
requires_grad=False,
)
self.squared_bins = bins**2
upper = torch.cat(
[self.squared_bins[1:], self.squared_bins.new_tensor([self.inf])], dim=-1
)
if self.training:
x = x.float()
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
d = ((d > self.squared_bins) * (d < upper)).type(self.linear.weight.dtype)
d = self.linear(d)
return d
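# recyle_pos (editor's note): the recycled pseudo-beta coordinates are binned into a
# one-hot distogram over squared distances, using lazily-built squared bin edges with an
# open last bin capped by self.inf, and then projected to d_pair; this is the coordinate
# part of the recycling signal added to z at the start of each cycle.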
class TemplateAngleEmbedder(nn.Module):
def __init__(
self,
d_in: int,
d_out: int,
**kwargs,
):
super(TemplateAngleEmbedder, self).__init__()
self.d_out = d_out
self.d_in = d_in
self.linear_1 = Linear(self.d_in, self.d_out, init="relu")
self.act = nn.GELU()
self.linear_2 = Linear(self.d_out, self.d_out, init="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x.type(self.linear_1.weight.dtype))
x = self.act(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
def __init__(
self,
d_in: int,
v2_d_in: list,
d_out: int,
d_pair: int,
v2_feature: bool = False,
**kwargs,
):
super(TemplatePairEmbedder, self).__init__()
self.d_out = d_out
self.v2_feature = v2_feature
if self.v2_feature:
self.d_in = v2_d_in
self.linear = SimpleModuleList()
for d_in in self.d_in:
self.linear.append(Linear(d_in, self.d_out, init="relu"))
self.z_layer_norm = LayerNorm(d_pair)
self.z_linear = Linear(d_pair, self.d_out, init="relu")
else:
self.d_in = d_in
self.linear = Linear(self.d_in, self.d_out, init="relu")
def forward(
self,
x,
z,
) -> torch.Tensor:
if not self.v2_feature:
x = self.linear(x.type(self.linear.weight.dtype))
return x
else:
dtype = self.z_linear.weight.dtype
t = self.linear[0](x[0].type(dtype))
for i, s in enumerate(x[1:]):
t = residual(t, self.linear[i + 1](s.type(dtype)), self.training)
t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training)
return t
class ExtraMSAEmbedder(nn.Module):
def __init__(
self,
d_in: int,
d_out: int,
**kwargs,
):
super(ExtraMSAEmbedder, self).__init__()
self.d_in = d_in
self.d_out = d_out
self.linear = Linear(self.d_in, self.d_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x.type(self.linear.weight.dtype))
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial
from .common import (
Linear,
Transition,
OuterProductMean,
SimpleModuleList,
residual,
bias_dropout_residual,
tri_mul_residual,
)
from .attentions import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
TriangleAttentionStarting,
TriangleAttentionEnding,
)
from .triangle_multiplication import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from unicore.utils import checkpoint_sequential
class EvoformerIteration(nn.Module):
def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
num_heads_msa: int,
num_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerIteration, self).__init__()
self._is_extra_msa_stack = _is_extra_msa_stack
self.outer_product_mean_first = outer_product_mean_first
self.msa_att_row = MSARowAttentionWithPairBias(
d_msa=d_msa,
d_pair=d_pair,
d_hid=d_hid_msa_att,
num_heads=num_heads_msa,
)
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
d_in=d_msa,
d_hid=d_hid_msa_att,
num_heads=num_heads_msa,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
d_msa,
d_hid_msa_att,
num_heads_msa,
)
self.msa_transition = Transition(
d_in=d_msa,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
d_msa,
d_pair,
d_hid_opm,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
d_pair,
d_hid_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
d_pair,
d_hid_mul,
)
self.tri_att_start = TriangleAttentionStarting(
d_pair,
d_hid_pair_att,
num_heads_pair,
)
self.tri_att_end = TriangleAttentionEnding(
d_pair,
d_hid_pair_att,
num_heads_pair,
)
self.pair_transition = Transition(
d_in=d_pair,
n=transition_n,
)
self.row_dropout_share_dim = -3
self.col_dropout_share_dim = -2
self.msa_dropout = msa_dropout
self.pair_dropout = pair_dropout
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
msa_row_attn_mask: torch.Tensor,
msa_col_attn_mask: Optional[torch.Tensor],
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: Optional[int] = None,
block_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.outer_product_mean_first:
z = residual(
z, self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size),
self.training
)
m = bias_dropout_residual(
self.msa_att_row,
m,
self.msa_att_row(
m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size
),
self.row_dropout_share_dim,
self.msa_dropout,
self.training,
)
if self._is_extra_msa_stack:
m = residual(
m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size),
self.training
)
else:
m = bias_dropout_residual(
self.msa_att_col,
m,
self.msa_att_col(m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.msa_dropout,
self.training,
)
m = residual(
m, self.msa_transition(m, chunk_size=chunk_size),
self.training
)
if not self.outer_product_mean_first:
z = residual(
z, self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size),
self.training
)
z = tri_mul_residual(
self.tri_mul_out,
z,
self.tri_mul_out(z, mask=pair_mask, block_size=block_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
block_size=block_size,
)
z = tri_mul_residual(
self.tri_mul_in,
z,
self.tri_mul_in(z, mask=pair_mask, block_size=block_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
block_size=block_size,
)
z = bias_dropout_residual(
self.tri_att_start,
z,
self.tri_att_start(z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
)
z = bias_dropout_residual(
self.tri_att_end,
z,
self.tri_att_end(z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.pair_dropout,
self.training,
)
z = residual(
z, self.pair_transition(z, chunk_size=chunk_size),
self.training
)
return m, z
class EvoformerStack(nn.Module):
def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
d_single: int,
num_heads_msa: int,
num_heads_pair: int,
num_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
**kwargs,
):
super(EvoformerStack, self).__init__()
self._is_extra_msa_stack = _is_extra_msa_stack
self.blocks = SimpleModuleList()
for _ in range(num_blocks):
self.blocks.append(
EvoformerIteration(
d_msa=d_msa,
d_pair=d_pair,
d_hid_msa_att=d_hid_msa_att,
d_hid_opm=d_hid_opm,
d_hid_mul=d_hid_mul,
d_hid_pair_att=d_hid_pair_att,
num_heads_msa=num_heads_msa,
num_heads_pair=num_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
outer_product_mean_first=outer_product_mean_first,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
)
)
if not self._is_extra_msa_stack:
self.linear = Linear(d_msa, d_single)
else:
self.linear = None
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
msa_row_attn_mask: torch.Tensor,
msa_col_attn_mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: int,
block_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_attn_mask,
msa_col_attn_mask=msa_col_attn_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size
)
for b in self.blocks
]
m, z = checkpoint_sequential(
blocks,
input=(m, z),
)
s = None
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)
return m, z, s
class ExtraMSAStack(EvoformerStack):
def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
num_heads_msa: int,
num_heads_pair: int,
num_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
**kwargs,
):
super(ExtraMSAStack, self).__init__(
d_msa=d_msa,
d_pair=d_pair,
d_hid_msa_att=d_hid_msa_att,
d_hid_opm=d_hid_opm,
d_hid_mul=d_hid_mul,
d_hid_pair_att=d_hid_pair_att,
d_single=None,
num_heads_msa=num_heads_msa,
num_heads_pair=num_heads_pair,
num_blocks=num_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
outer_product_mean_first=outer_product_mean_first,
inf=inf,
eps=eps,
_is_extra_msa_stack=True,
)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
msa_row_attn_mask: torch.Tensor = None,
msa_col_attn_mask: torch.Tensor = None,
tri_start_attn_mask: torch.Tensor = None,
tri_end_attn_mask: torch.Tensor = None,
chunk_size: int = None,
block_size: int = None,
) -> torch.Tensor:
_, z, _ = super().forward(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_attn_mask,
msa_col_attn_mask=msa_col_attn_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size
)
return z
import torch
import torch.nn as nn
from typing import Dict
from unifold.data import residue_constants as rc
from .frame import Frame
from unicore.utils import (
batched_gather,
one_hot,
)
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
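# pseudo_beta_fn (editor's note): the pseudo-beta atom is CB for every residue type except
# glycine, which has no CB and falls back to CA; when atom masks are given, the same CA/CB
# selection is applied to produce the pseudo-beta mask.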
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
dim=-2,
num_batch_dims=len(atom14.shape[:-2]),
)
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
def build_template_angle_feat(template_feats, v2_feature=False):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
if not v2_feature:
alt_torsion_angles_sin_cos = template_feats[
"template_alt_torsion_angles_sin_cos"
]
template_angle_feat = torch.cat(
[
one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(
*alt_torsion_angles_sin_cos.shape[:-2], 14
),
torsion_angles_mask,
],
dim=-1,
)
template_angle_mask = torsion_angles_mask[..., 2]
else:
chi_mask = torsion_angles_mask[..., 3:]
chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask
chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask
template_angle_feat = torch.cat(
[
one_hot(template_aatype, 22),
chi_angles_sin,
chi_angles_cos,
chi_mask,
],
dim=-1,
)
template_angle_mask = chi_mask[..., 0]
return template_angle_feat, template_angle_mask
def build_template_pair_feat(
batch,
min_bin,
max_bin,
num_bins,
eps=1e-20,
inf=1e8,
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
to_concat.append(
aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1
)
)
to_concat.append(
aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
)
to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]
return act
def build_template_pair_feat_v2(
batch,
min_bin,
max_bin,
num_bins,
multichain_mask_2d=None,
eps=1e-20,
inf=1e8,
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
if multichain_mask_2d is not None:
template_mask_2d *= multichain_mask_2d
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram *= template_mask_2d[..., None]
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
to_concat.append(
aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1
)
)
to_concat.append(
aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
rigids = Frame.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., None, :]
if multichain_mask_2d is not None:
backbone_mask_2d *= multichain_mask_2d
inv_distance_scalar = inv_distance_scalar * backbone_mask_2d
unit_vector_data = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1))
to_concat.append(backbone_mask_2d[..., None])
return to_concat
def build_extra_msa_feat(batch):
msa_1hot = one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
batch["extra_msa_has_deletion"].unsqueeze(-1),
batch["extra_msa_deletion_value"].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
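# Illustrative usage sketch (ours, not part of the original module): a minimal
# fake batch shows the expected feature width of build_extra_msa_feat
# (23-way residue one-hot + has-deletion + deletion-value = 25 channels).
def _example_build_extra_msa_feat():
    num_extra_seq, num_res = 16, 8
    batch = {
        "extra_msa": torch.randint(0, 23, (num_extra_seq, num_res)),
        "extra_msa_has_deletion": torch.rand(num_extra_seq, num_res),
        "extra_msa_deletion_value": torch.rand(num_extra_seq, num_res),
    }
    feat = build_extra_msa_feat(batch)
    assert feat.shape == (num_extra_seq, num_res, 25)
    return feat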
from __future__ import annotations
from typing import Tuple, Any, Sequence, Callable, Optional, Iterable
import numpy as np
import torch
def zero_translation(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device("cpu"),
requires_grad: bool = False,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad
)
return trans
# pylint: disable=bad-whitespace
_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr
_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii
_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj
_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk
_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij
_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik
_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk
_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir
_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr
_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr
_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9)
_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT)
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]
_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC)
class Rotation:
def __init__(
self,
mat: torch.Tensor,
):
if mat.shape[-2:] != (3, 3):
raise ValueError(f"incorrect rotation shape: {mat.shape}")
self._mat = mat
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device("cpu"),
requires_grad: bool = False,
) -> Rotation:
mat = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
mat = mat.view(*((1,) * len(shape)), 3, 3)
mat = mat.expand(*shape, -1, -1)
return Rotation(mat)
@staticmethod
def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return (a.float() @ b.float()).type(a.dtype)
@staticmethod
def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype)
def __getitem__(self, index: Any) -> Rotation:
if not isinstance(index, tuple):
index = (index,)
return Rotation(mat=self._mat[index + (slice(None), slice(None))])
def __mul__(self, right: Any) -> Rotation:
if isinstance(right, (int, float)):
return Rotation(mat=self._mat * right)
elif isinstance(right, torch.Tensor):
return Rotation(mat=self._mat * right[..., None, None])
else:
raise TypeError(
f"multiplicand must be a tensor or a number, got {type(right)}."
)
def __rmul__(self, left: Any) -> Rotation:
return self.__mul__(left)
def __matmul__(self, other: Rotation) -> Rotation:
new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat)
return Rotation(mat=new_mat)
@property
def _inv_mat(self):
return self._mat.transpose(-1, -2)
@property
def rot_mat(self) -> torch.Tensor:
return self._mat
def invert(self) -> Rotation:
return Rotation(mat=self._inv_mat)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
return Rotation.mat_mul_vec(self._mat, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
return Rotation.mat_mul_vec(self._inv_mat, pts)
# inherit tensor behaviors
@property
def shape(self) -> torch.Size:
s = self._mat.shape[:-2]
return s
@property
def dtype(self) -> torch.dtype:
return self._mat.dtype
@property
def device(self) -> torch.device:
return self._mat.device
@property
def requires_grad(self) -> bool:
return self._mat.requires_grad
def unsqueeze(self, dim: int) -> Rotation:
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(mat=rot_mats)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
mat = self._mat.view(self._mat.shape[:-2] + (9,))
mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1)
mat = mat.view(mat.shape[:-1] + (3, 3))
return Rotation(mat=mat)
@staticmethod
def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
rot_mats = [r.rot_mat for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
return Rotation(mat=rot_mats)
def cuda(self) -> Rotation:
return Rotation(mat=self._mat.cuda())
def to(
self, device: Optional[torch.device], dtype: Optional[torch.dtype]
) -> Rotation:
return Rotation(mat=self._mat.to(device=device, dtype=dtype))
def type(self, dtype: Optional[torch.dtype]) -> Rotation:
return Rotation(mat=self._mat.type(dtype))
def detach(self) -> Rotation:
return Rotation(mat=self._mat.detach())
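# Illustrative usage sketch (ours): wrap an explicit 3x3 matrix in a Rotation and
# check that apply / invert_apply round-trip a point. A 90-degree rotation about
# the z axis maps (1, 0, 0) to (0, 1, 0).
def _example_rotation_usage():
    mat = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    rot = Rotation(mat=mat)
    x = torch.tensor([1.0, 0.0, 0.0])
    y = rot.apply(x)
    assert torch.allclose(y, torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)
    assert torch.allclose(rot.invert_apply(y), x, atol=1e-6)
    return rot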
class Frame:
def __init__(
self,
rotation: Optional[Rotation],
translation: Optional[torch.Tensor],
):
if rotation is None and translation is None:
rotation = Rotation.identity((0,))
translation = zero_translation((0,))
elif translation is None:
translation = zero_translation(
rotation.shape, rotation.dtype, rotation.device, rotation.requires_grad
)
elif rotation is None:
rotation = Rotation.identity(
translation.shape[:-1],
translation.dtype,
translation.device,
translation.requires_grad,
)
if (rotation.shape != translation.shape[:-1]) or (
rotation.device != translation.device
):
raise ValueError("RotationMatrix and translation incompatible")
self._r = rotation
self._t = translation
@staticmethod
def identity(
shape: Iterable[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device("cpu"),
requires_grad: bool = False,
) -> Frame:
return Frame(
Rotation.identity(shape, dtype, device, requires_grad),
zero_translation(shape, dtype, device, requires_grad),
)
def __getitem__(
self,
index: Any,
) -> Frame:
        if not isinstance(index, tuple):
index = (index,)
return Frame(
self._r[index],
self._t[index + (slice(None),)],
)
def __mul__(
self,
right: torch.Tensor,
) -> Frame:
if not (isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._r * right
new_trans = self._t * right[..., None]
return Frame(new_rots, new_trans)
def __rmul__(
self,
left: torch.Tensor,
) -> Frame:
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
s = self._t.shape[:-1]
return s
@property
def device(self) -> torch.device:
return self._t.device
def get_rots(self) -> Rotation:
return self._r
def get_trans(self) -> torch.Tensor:
return self._t
def compose(
self,
other: Frame,
) -> Frame:
new_rot = self._r @ other._r
new_trans = self._r.apply(other._t) + self._t
return Frame(new_rot, new_trans)
def apply(
self,
pts: torch.Tensor,
) -> torch.Tensor:
rotated = self._r.apply(pts)
return rotated + self._t
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
pts = pts - self._t
return self._r.invert_apply(pts)
def invert(self) -> Frame:
rot_inv = self._r.invert()
trn_inv = rot_inv.apply(self._t)
return Frame(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Frame:
new_rots = self._r.map_tensor_fn(fn)
new_trans = torch.stack(list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1)
return Frame(new_rots, new_trans)
def to_tensor_4x4(self) -> torch.Tensor:
tensor = self._t.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._r.rot_mat
tensor[..., :3, 3] = self._t
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_tensor_4x4(t: torch.Tensor) -> Frame:
if t.shape[-2:] != (4, 4):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(mat=t[..., :3, :3])
trans = t[..., :3, 3]
return Frame(rots, trans)
@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor,
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8,
) -> Frame:
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
rot_obj = Rotation(mat=rots)
return Frame(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(
self,
dim: int,
) -> Frame:
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._r.unsqueeze(dim)
trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1)
return Frame(rots, trans)
@staticmethod
def cat(
Ts: Sequence[Frame],
dim: int,
) -> Frame:
rots = Rotation.cat([T._r for T in Ts], dim)
trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1)
return Frame(rots, trans)
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame:
return Frame(fn(self._r), self._t)
def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Frame:
return Frame(self._r, fn(self._t))
def scale_translation(self, trans_scale_factor: float) -> Frame:
fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def stop_rot_gradient(self) -> Frame:
fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
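    # Added note: make_transform_from_reference builds a backbone frame from
    # N/CA/C coordinates by centering on CA, rotating C onto the +x axis with two
    # planar rotations, rotating N into the xy-plane with a third, and returning
    # the transpose of that product (local -> global) with CA as the translation.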
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
input_dtype = ca_xyz.dtype
n_xyz = n_xyz.float()
ca_xyz = ca_xyz.float()
c_xyz = c_xyz.float()
n_xyz = n_xyz - ca_xyz
c_xyz = c_xyz - ca_xyz
c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x**2 + c_y**2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
c1_rots[..., 0, 0] = cos_c1
c1_rots[..., 0, 1] = -1 * sin_c1
c1_rots[..., 1, 0] = sin_c1
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2)
sin_c2 = d_pair / norm
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots)
n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y**2 + n_z**2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
rots = Rotation.mat_mul_mat(n_rots, c_rots)
rots = rots.transpose(-1, -2)
rot_obj = Rotation(mat=rots.type(input_dtype))
return Frame(rot_obj, ca_xyz.type(input_dtype))
def cuda(self) -> Frame:
return Frame(self._r.cuda(), self._t.cuda())
@property
def dtype(self) -> torch.dtype:
assert self._r.dtype == self._t.dtype
return self._r.dtype
def type(self, dtype) -> Frame:
return Frame(self._r.type(dtype), self._t.type(dtype))
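# Illustrative usage sketch (ours): build per-residue frames from three reference
# points and verify that apply followed by invert_apply recovers the input.
def _example_frame_roundtrip():
    n_res = 4
    p_neg_x = torch.randn(n_res, 3)
    origin = torch.randn(n_res, 3)
    p_xy = torch.randn(n_res, 3)
    frames = Frame.from_3_points(p_neg_x, origin, p_xy)
    pts = torch.randn(n_res, 3)
    recovered = frames.invert_apply(frames.apply(pts))
    assert torch.allclose(recovered, pts, atol=1e-4)
    return frames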
class Quaternion:
def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor):
if quaternion.shape[-1] != 4:
raise ValueError(f"incorrect quaternion shape: {quaternion.shape}")
self._q = quaternion
self._t = translation
@staticmethod
def identity(
shape: Iterable[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device("cpu"),
requires_grad: bool = False,
) -> Quaternion:
trans = zero_translation(shape, dtype, device, requires_grad)
quats = torch.zeros(
(*shape, 4), dtype=dtype, device=device, requires_grad=requires_grad
)
with torch.no_grad():
quats[..., 0] = 1
return Quaternion(quats, trans)
def get_quats(self):
return self._q
def get_trans(self):
return self._t
def get_rot_mats(self):
quats = self.get_quats()
rot_mats = Quaternion.quat_to_rot(quats)
return rot_mats
@staticmethod
def quat_to_rot(normalized_quat):
global _QUAT_TO_ROT_tensor
dtype = normalized_quat.dtype
normalized_quat = normalized_quat.float()
if _QUAT_TO_ROT_tensor.device != normalized_quat.device:
_QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to(normalized_quat.device)
rot_tensor = torch.sum(
_QUAT_TO_ROT_tensor
* normalized_quat[..., :, None, None]
* normalized_quat[..., None, :, None],
dim=(-3, -2),
)
rot_tensor = rot_tensor.type(dtype)
rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3)
return rot_tensor
@staticmethod
def normalize_quat(quats):
dtype = quats.dtype
quats = quats.float()
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
quats = quats.type(dtype)
return quats
@staticmethod
def quat_multiply_by_vec(quat, vec):
dtype = quat.dtype
quat = quat.float()
vec = vec.float()
global _QUAT_MULTIPLY_BY_VEC_tensor
if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device:
_QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to(quat.device)
mat = _QUAT_MULTIPLY_BY_VEC_tensor
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None],
dim=(-3, -2),
).type(dtype)
def compose_q_update_vec(
self, q_update_vec: torch.Tensor, normalize_quats: bool = True
) -> torch.Tensor:
quats = self.get_quats()
new_quats = quats + Quaternion.quat_multiply_by_vec(quats, q_update_vec)
if normalize_quats:
new_quats = Quaternion.normalize_quat(new_quats)
return new_quats
def compose_update_vec(
self,
update_vec: torch.Tensor,
pre_rot_mat: Rotation,
) -> Quaternion:
q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:]
new_quats = self.compose_q_update_vec(q_vec)
trans_update = pre_rot_mat.apply(t_vec)
new_trans = self._t + trans_update
return Quaternion(new_quats, new_trans)
def stop_rot_gradient(self) -> Quaternion:
return Quaternion(self._q.detach(), self._t)
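# Illustrative usage sketch (ours): the identity quaternion (1, 0, 0, 0) converts
# to the identity rotation matrix through the precomputed _QUAT_TO_ROT table.
def _example_quaternion_identity():
    q = Quaternion.identity((2,))
    rot_mats = q.get_rot_mats()  # [2, 3, 3]
    assert torch.allclose(rot_mats, torch.eye(3).expand(2, 3, 3), atol=1e-6)
    return q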
import math
import torch
import torch.nn as nn
from typing import Tuple
from .common import Linear, SimpleModuleList, residual
from .attentions import gen_attn_mask
from unifold.data.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from .frame import Rotation, Frame, Quaternion
from unicore.utils import (
one_hot,
dict_multimap,
permute_final_dims,
)
from unicore.modules import LayerNorm, softmax_dropout
def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
weights.fill_(softplus_inverse_1)
def torsion_angles_to_frames(
frame: Frame,
alpha: torch.Tensor,
aatype: torch.Tensor,
default_frames: torch.Tensor,
):
default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...])
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Frame(Rotation(mat=all_rots), None)
all_frames = default_frame.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Frame.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = frame[..., None].compose(all_frames_to_bb)
return all_frames_to_global
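# Illustrative usage sketch (ours): map 7 (sin, cos) torsion angles per residue
# onto that residue's 8 rigid-group frames, starting from identity backbone
# frames and the default frames tabulated in residue_constants.
def _example_torsion_angles_to_frames():
    n_res = 4
    aatype = torch.randint(0, 20, (n_res,))
    alpha = torch.randn(n_res, 7, 2)
    alpha = alpha / (alpha.norm(dim=-1, keepdim=True) + 1e-8)  # unit (sin, cos)
    default_frames = torch.tensor(restype_rigid_group_default_frame, dtype=torch.float)
    backbone = Frame.identity((n_res,))
    frames = torsion_angles_to_frames(backbone, alpha, aatype, default_frames)
    assert frames.shape == (n_res, 8)
    return frames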
def frames_and_literature_positions_to_atom14_pos(
frame: Frame,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
group_mask = group_idx[aatype, ...]
group_mask = one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
t_atoms_to_global = frame[..., None, :] * group_mask
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
return pred_positions
class SideChainAngleResnetIteration(nn.Module):
def __init__(self, d_hid):
super(SideChainAngleResnetIteration, self).__init__()
self.d_hid = d_hid
self.linear_1 = Linear(self.d_hid, self.d_hid, init="relu")
self.act = nn.GELU()
self.linear_2 = Linear(self.d_hid, self.d_hid, init="final")
def forward(self, s: torch.Tensor) -> torch.Tensor:
x = self.act(s)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return residual(s, x, self.training)
class SidechainAngleResnet(nn.Module):
def __init__(self, d_in, d_hid, num_blocks, num_angles):
super(SidechainAngleResnet, self).__init__()
self.linear_in = Linear(d_in, d_hid)
self.act = nn.GELU()
self.linear_initial = Linear(d_in, d_hid)
self.layers = SimpleModuleList()
for _ in range(num_blocks):
self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid))
self.linear_out = Linear(d_hid, num_angles * 2)
def forward(
self, s: torch.Tensor, initial_s: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
initial_s = self.linear_initial(self.act(initial_s))
s = self.linear_in(self.act(s))
s = s + initial_s
for layer in self.layers:
s = layer(s)
s = self.linear_out(self.act(s))
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s.float() ** 2, dim=-1, keepdim=True),
min=1e-12,
)
)
s = s.float() / norm_denom
return unnormalized_s, s.type(unnormalized_s.dtype)
class InvariantPointAttention(nn.Module):
def __init__(
self,
d_single: int,
d_pair: int,
d_hid: int,
num_heads: int,
num_qk_points: int,
num_v_points: int,
separate_kv: bool = False,
bias: bool = True,
eps: float = 1e-8,
):
super(InvariantPointAttention, self).__init__()
self.d_hid = d_hid
self.num_heads = num_heads
self.num_qk_points = num_qk_points
self.num_v_points = num_v_points
self.eps = eps
hc = self.d_hid * self.num_heads
self.linear_q = Linear(d_single, hc, bias=bias)
self.separate_kv = separate_kv
if self.separate_kv:
self.linear_k = Linear(d_single, hc, bias=bias)
self.linear_v = Linear(d_single, hc, bias=bias)
else:
self.linear_kv = Linear(d_single, 2 * hc, bias=bias)
hpq = self.num_heads * self.num_qk_points * 3
self.linear_q_points = Linear(d_single, hpq)
hpk = self.num_heads * self.num_qk_points * 3
hpv = self.num_heads * self.num_v_points * 3
if self.separate_kv:
self.linear_k_points = Linear(d_single, hpk)
self.linear_v_points = Linear(d_single, hpv)
else:
hpkv = hpk + hpv
self.linear_kv_points = Linear(d_single, hpkv)
self.linear_b = Linear(d_pair, self.num_heads)
self.head_weights = nn.Parameter(torch.zeros((num_heads)))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.num_heads * (d_pair + self.d_hid + self.num_v_points * 4)
self.linear_out = Linear(concat_out_dim, d_single, init="final")
self.softplus = nn.Softplus()
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
f: Frame,
square_mask: torch.Tensor,
) -> torch.Tensor:
q = self.linear_q(s)
q = q.view(q.shape[:-1] + (self.num_heads, -1))
if self.separate_kv:
k = self.linear_k(s)
v = self.linear_v(s)
k = k.view(k.shape[:-1] + (self.num_heads, -1))
v = v.view(v.shape[:-1] + (self.num_heads, -1))
else:
kv = self.linear_kv(s)
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
k, v = torch.split(kv, self.d_hid, dim=-1)
q_pts = self.linear_q_points(s)
def process_points(pts, no_points):
shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3)
if self.separate_kv:
                # alphafold-multimer uses a different point layout
pts = pts.view(pts.shape[:-1] + (self.num_heads, no_points * 3))
pts = torch.split(pts, pts.shape[-1] // 3, dim=-1)
pts = torch.stack(pts, dim=-1).view(*shape)
pts = f[..., None].apply(pts)
pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3))
return pts
q_pts = process_points(q_pts, self.num_qk_points)
if self.separate_kv:
k_pts = self.linear_k_points(s)
v_pts = self.linear_v_points(s)
k_pts = process_points(k_pts, self.num_qk_points)
v_pts = process_points(v_pts, self.num_v_points)
else:
kv_pts = self.linear_kv_points(s)
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = f[..., None].apply(kv_pts)
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))
k_pts, v_pts = torch.split(
kv_pts, [self.num_qk_points, self.num_v_points], dim=-2
)
bias = self.linear_b(z)
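        # Added note: the attention logits sum three terms (scalar q-k products,
        # the pair bias above, and squared point distances); each carries a 1/3
        # factor in its scale so their variances stay balanced, and the 9/2 in
        # the point weights follows the AlphaFold IPA formulation.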
attn = torch.matmul(
permute_final_dims(q, (1, 0, 2)),
permute_final_dims(k, (1, 2, 0)),
)
if self.training:
attn = attn * math.sqrt(1.0 / (3 * self.d_hid))
attn = attn + (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att.float() ** 2
else:
attn *= math.sqrt(1.0 / (3 * self.d_hid))
attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att *= pt_att
pt_att = pt_att.sum(dim=-1)
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.num_qk_points * 9.0 / 2))
)
pt_att *= head_weights * (-0.5)
pt_att = torch.sum(pt_att, dim=-1)
pt_att = permute_final_dims(pt_att, (2, 0, 1))
attn += square_mask
attn = softmax_dropout(attn, 0, self.training, bias=pt_att.type(attn.dtype))
del pt_att, q_pts, k_pts, bias
o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3)
o = o.contiguous().view(*o.shape[:-2], -1)
del q, k, v
o_pts = torch.sum(
(
attn[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
o_pts = permute_final_dims(o_pts, (2, 0, 3, 1))
o_pts = f[..., None, None].invert_apply(o_pts)
if self.training:
o_pts_norm = torch.sqrt(torch.sum(o_pts.float() ** 2, dim=-1) + self.eps).type(
o_pts.dtype
)
else:
o_pts_norm = torch.sqrt(torch.sum(o_pts ** 2, dim=-1) + self.eps).type(
o_pts.dtype
)
o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1)
o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3)
o_pair = torch.matmul(attn.transpose(-2, -3), z)
o_pair = o_pair.view(*o_pair.shape[:-2], -1)
s = self.linear_out(
torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), dim=-1)
)
return s
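# Illustrative shape-check sketch (ours): run InvariantPointAttention on random
# inputs with identity frames and an all-zero (fully visible) additive mask.
# The hidden sizes below are placeholders, not a released configuration.
def _example_invariant_point_attention():
    ipa = InvariantPointAttention(
        d_single=32, d_pair=16, d_hid=8, num_heads=4,
        num_qk_points=4, num_v_points=8,
    )
    n_res = 6
    s = torch.randn(1, n_res, 32)
    z = torch.randn(1, n_res, n_res, 16)
    frames = Frame.identity((1, n_res))
    square_mask = s.new_zeros(1, 1, n_res, n_res)
    out = ipa(s, z, frames, square_mask)
    assert out.shape == (1, n_res, 32)
    return out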
class BackboneUpdate(nn.Module):
def __init__(self, d_single):
super(BackboneUpdate, self).__init__()
self.linear = Linear(d_single, 6, init="final")
def forward(self, s: torch.Tensor):
return self.linear(s)
class StructureModuleTransitionLayer(nn.Module):
def __init__(self, c):
super(StructureModuleTransitionLayer, self).__init__()
self.linear_1 = Linear(c, c, init="relu")
self.linear_2 = Linear(c, c, init="relu")
self.act = nn.GELU()
self.linear_3 = Linear(c, c, init="final")
def forward(self, s):
s_old = s
s = self.linear_1(s)
s = self.act(s)
s = self.linear_2(s)
s = self.act(s)
s = self.linear_3(s)
s = residual(s_old, s, self.training)
return s
class StructureModuleTransition(nn.Module):
def __init__(self, c, num_layers, dropout_rate):
super(StructureModuleTransition, self).__init__()
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.layers = SimpleModuleList()
for _ in range(self.num_layers):
self.layers.append(StructureModuleTransitionLayer(c))
self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = LayerNorm(c)
def forward(self, s):
for layer in self.layers:
s = layer(s)
s = self.dropout(s)
s = self.layer_norm(s)
return s
class StructureModule(nn.Module):
def __init__(
self,
d_single,
d_pair,
d_ipa,
d_angle,
num_heads_ipa,
num_qk_points,
num_v_points,
dropout_rate,
num_blocks,
no_transition_layers,
num_resnet_blocks,
num_angles,
trans_scale_factor,
separate_kv,
ipa_bias,
epsilon,
inf,
**kwargs,
):
super(StructureModule, self).__init__()
self.num_blocks = num_blocks
self.trans_scale_factor = trans_scale_factor
self.default_frames = None
self.group_idx = None
self.atom_mask = None
self.lit_positions = None
self.inf = inf
self.layer_norm_s = LayerNorm(d_single)
self.layer_norm_z = LayerNorm(d_pair)
self.linear_in = Linear(d_single, d_single)
self.ipa = InvariantPointAttention(
d_single,
d_pair,
d_ipa,
num_heads_ipa,
num_qk_points,
num_v_points,
separate_kv=separate_kv,
bias=ipa_bias,
eps=epsilon,
)
self.ipa_dropout = nn.Dropout(dropout_rate)
self.layer_norm_ipa = LayerNorm(d_single)
self.transition = StructureModuleTransition(
d_single,
no_transition_layers,
dropout_rate,
)
self.bb_update = BackboneUpdate(d_single)
self.angle_resnet = SidechainAngleResnet(
d_single,
d_angle,
num_resnet_blocks,
num_angles,
)
def forward(
self,
s,
z,
aatype,
mask=None,
):
if mask is None:
mask = s.new_ones(s.shape[:-1])
# generate square mask
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3)
s = self.layer_norm_s(s)
z = self.layer_norm_z(z)
initial_s = s
s = self.linear_in(s)
quat_encoder = Quaternion.identity(
s.shape[:-1],
s.dtype,
s.device,
requires_grad=False,
)
backb_to_global = Frame(
Rotation(
mat=quat_encoder.get_rot_mats(),
),
quat_encoder.get_trans(),
)
outputs = []
for i in range(self.num_blocks):
s = residual(s, self.ipa(s, z, backb_to_global, square_mask), self.training)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# update quaternion encoder
# use backb_to_global to avoid quat-to-rot conversion
quat_encoder = quat_encoder.compose_update_vec(
self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()
)
            # the angle resnet always takes the pre-IPA single representation (initial_s)
unnormalized_angles, angles = self.angle_resnet(s, initial_s)
# convert quaternion to rotation matrix
backb_to_global = Frame(
Rotation(
mat=quat_encoder.get_rot_mats(),
),
quat_encoder.get_trans(),
)
if i == self.num_blocks - 1:
all_frames_to_global = self.torsion_angles_to_frames(
backb_to_global.scale_translation(self.trans_scale_factor),
angles,
aatype,
)
pred_positions = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
)
preds = {
"frames": backb_to_global.scale_translation(
self.trans_scale_factor
).to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
}
outputs.append(preds)
if i < (self.num_blocks - 1):
# stop gradient in iteration
quat_encoder = quat_encoder.stop_rot_gradient()
backb_to_global = backb_to_global.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
outputs["sidechain_frames"] = all_frames_to_global.to_tensor_4x4()
outputs["positions"] = pred_positions
outputs["single"] = s
return outputs
def _init_residue_constants(self, float_dtype, device):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, frame, alpha, aatype):
self._init_residue_constants(alpha.dtype, alpha.device)
return torsion_angles_to_frames(frame, alpha, aatype, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(self, frame, aatype):
self._init_residue_constants(frame.get_rots().dtype, frame.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
frame,
aatype,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
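# Illustrative usage sketch (ours): the hyperparameters below are placeholders
# chosen only to exercise the module's shapes, not values from a released
# Uni-Fold config; num_angles must be 7 so the 8 rigid groups line up.
def _example_structure_module():
    sm = StructureModule(
        d_single=64, d_pair=32, d_ipa=16, d_angle=32,
        num_heads_ipa=4, num_qk_points=4, num_v_points=8,
        dropout_rate=0.1, num_blocks=2, no_transition_layers=1,
        num_resnet_blocks=2, num_angles=7, trans_scale_factor=10.0,
        separate_kv=False, ipa_bias=True, epsilon=1e-8, inf=1e5,
    )
    n_res = 8
    s = torch.randn(1, n_res, 64)
    z = torch.randn(1, n_res, n_res, 32)
    aatype = torch.randint(0, 20, (1, n_res))
    out = sm(s, z, aatype)
    # per-block frames are stacked on a leading dim; positions come from the last block
    assert out["frames"].shape == (2, 1, n_res, 4, 4)
    assert out["positions"].shape == (1, n_res, 14, 3)
    return out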
from functools import partial
from typing import Optional, List, Tuple
import math
import torch
import torch.nn as nn
from .attentions import Attention
from .common import (
SimpleModuleList,
residual,
bias_dropout_residual,
tri_mul_residual,
)
from .common import Linear, Transition, chunk_layer
from .attentions import (
gen_attn_mask,
TriangleAttentionStarting,
TriangleAttentionEnding,
)
from .featurization import build_template_pair_feat_v2
from .triangle_multiplication import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from unicore.utils import (
checkpoint_sequential,
permute_final_dims,
tensor_tree_map
)
from unicore.modules import LayerNorm
class TemplatePointwiseAttention(nn.Module):
def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs):
super(TemplatePointwiseAttention, self).__init__()
self.inf = inf
self.mha = Attention(
d_pair,
d_template,
d_template,
d_hid,
num_heads,
gating=False,
)
def _chunk(
self,
z: torch.Tensor,
t: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q": z,
"k": t,
"v": t,
"mask": mask,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
num_batch_dims=len(z.shape[:-2]),
)
def forward(
self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, None, :]
z = z.unsqueeze(-2)
t = permute_final_dims(t, (1, 2, 0, 3))
if chunk_size is not None:
z = self._chunk(z, t, mask, chunk_size)
else:
z = self.mha(z, t, t, mask=mask)
z = z.squeeze(-2)
return z
class TemplateProjection(nn.Module):
def __init__(self, d_template, d_pair, **kwargs):
super(TemplateProjection, self).__init__()
        self.d_template = d_template
        self.act = nn.ReLU()
        self.output_linear = Linear(d_template, d_pair, init="relu")
    def forward(self, t, z) -> torch.Tensor:
        if t is None:
            # handle the no-template case: build zeros of template width, since
            # output_linear projects d_template -> d_pair (torch.Size is
            # immutable, so copy it to a list before editing the last dim)
            shape = list(z.shape)
            shape[-1] = self.d_template
t = torch.zeros(shape, dtype=z.dtype, device=z.device)
t = self.act(t)
z_t = self.output_linear(t)
return z_t
class TemplatePairStackBlock(nn.Module):
def __init__(
self,
d_template: int,
d_hid_tri_att: int,
d_hid_tri_mul: int,
num_heads: int,
pair_transition_n: int,
dropout_rate: float,
tri_attn_first: bool,
inf: float,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
self.tri_att_start = TriangleAttentionStarting(
d_template,
d_hid_tri_att,
num_heads,
)
self.tri_att_end = TriangleAttentionEnding(
d_template,
d_hid_tri_att,
num_heads,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
d_template,
d_hid_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
d_template,
d_hid_tri_mul,
)
self.pair_transition = Transition(
d_template,
pair_transition_n,
)
self.tri_attn_first = tri_attn_first
self.dropout = dropout_rate
self.row_dropout_share_dim = -3
self.col_dropout_share_dim = -2
def forward(
self,
s: torch.Tensor,
mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: Optional[int] = None,
block_size: Optional[int] = None,
):
if self.tri_attn_first:
s = bias_dropout_residual(
self.tri_att_start,
s,
self.tri_att_start(
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size
),
self.row_dropout_share_dim,
self.dropout,
self.training,
)
s = bias_dropout_residual(
self.tri_att_end,
s,
self.tri_att_end(s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.dropout,
self.training,
)
s = tri_mul_residual(
self.tri_mul_out,
s,
self.tri_mul_out(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)
s = tri_mul_residual(
self.tri_mul_in,
s,
self.tri_mul_in(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)
else:
s = tri_mul_residual(
self.tri_mul_out,
s,
self.tri_mul_out(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)
s = tri_mul_residual(
self.tri_mul_in,
s,
self.tri_mul_in(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)
s = bias_dropout_residual(
self.tri_att_start,
s,
self.tri_att_start(
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size
),
self.row_dropout_share_dim,
self.dropout,
self.training,
)
s = bias_dropout_residual(
self.tri_att_end,
s,
self.tri_att_end(s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.dropout,
self.training,
)
s = residual(
s,
self.pair_transition(
s,
chunk_size=chunk_size,
),
self.training
)
return s
class TemplatePairStack(nn.Module):
def __init__(
self,
d_template,
d_hid_tri_att,
d_hid_tri_mul,
num_blocks,
num_heads,
pair_transition_n,
dropout_rate,
tri_attn_first,
inf=1e9,
**kwargs,
):
super(TemplatePairStack, self).__init__()
self.blocks = SimpleModuleList()
for _ in range(num_blocks):
self.blocks.append(
TemplatePairStackBlock(
d_template=d_template,
d_hid_tri_att=d_hid_tri_att,
d_hid_tri_mul=d_hid_tri_mul,
num_heads=num_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
tri_attn_first=tri_attn_first,
)
)
self.layer_norm = LayerNorm(d_template)
def forward(
self,
        single_templates: Tuple[torch.Tensor, ...],
        mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
templ_dim: int,
chunk_size: int,
block_size: int,
return_mean: bool,
):
def one_template(i):
(s,) = checkpoint_sequential(
functions=[
partial(
b,
mask=mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size,
)
for b in self.blocks
],
input=(single_templates[i],),
)
return s
n_templ = len(single_templates)
if n_templ > 0:
new_single_templates = [one_template(0)]
if return_mean:
t = self.layer_norm(new_single_templates[0])
for i in range(1, n_templ):
s = one_template(i)
if return_mean:
t = residual(t, self.layer_norm(s), self.training)
else:
new_single_templates.append(s)
if return_mean:
if n_templ > 0:
t /= n_templ
else:
t = None
else:
t = torch.cat(
[s.unsqueeze(templ_dim) for s in new_single_templates], dim=templ_dim
)
t = self.layer_norm(t)
return t
from functools import partialmethod
from typing import Optional, List
import torch
import torch.nn as nn
from .common import Linear
from unicore.utils import (
permute_final_dims,
)
from unicore.modules import (
LayerNorm,
)
class TriangleMultiplication(nn.Module):
def __init__(self, d_pair, d_hid, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.outgoing = outgoing
self.linear_ab_p = Linear(d_pair, d_hid * 2)
self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating")
self.linear_g = Linear(d_pair, d_pair, init="gating")
self.linear_z = Linear(d_hid, d_pair, init="final")
self.layer_norm_in = LayerNorm(d_pair)
self.layer_norm_out = LayerNorm(d_hid)
self._alphafold_original_mode = False
def _chunk_2d(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
        block_size: Optional[int] = None,
) -> torch.Tensor:
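        # Added note: this inference-time path computes the triangle update in
        # (block_size x block_size) output tiles, re-projecting only the row and
        # column slices of z needed for each tile instead of materializing the
        # full a/b projections at once, which bounds peak memory.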
# avoid too small chunk size
# block_size = max(block_size, 256)
new_z = z.new_zeros(z.shape)
dim1 = z.shape[-3]
def _slice_linear(z, linear: Linear, a=True):
d_hid = linear.bias.shape[0] // 2
index = 0 if a else d_hid
p = (
nn.functional.linear(z, linear.weight[index : index + d_hid])
+ linear.bias[index : index + d_hid]
)
return p
def _chunk_projection(z, mask, a=True):
p = _slice_linear(z, self.linear_ab_p, a) * mask
p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
return p
num_chunk = (dim1 + block_size - 1) // block_size
for i in range(num_chunk):
chunk_start = i * block_size
chunk_end = min(chunk_start + block_size, dim1)
if self.outgoing:
a_chunk = _chunk_projection(
z[..., chunk_start:chunk_end, :, :],
mask[..., chunk_start:chunk_end, :, :],
a=True,
)
a_chunk = permute_final_dims(a_chunk, (2, 0, 1))
else:
a_chunk = _chunk_projection(
z[..., :, chunk_start:chunk_end, :],
mask[..., :, chunk_start:chunk_end, :],
a=True,
)
a_chunk = a_chunk.transpose(-1, -3)
for j in range(num_chunk):
j_chunk_start = j * block_size
j_chunk_end = min(j_chunk_start + block_size, dim1)
if self.outgoing:
b_chunk = _chunk_projection(
z[..., j_chunk_start:j_chunk_end, :, :],
mask[..., j_chunk_start:j_chunk_end, :, :],
a=False,
)
b_chunk = b_chunk.transpose(-1, -3)
else:
b_chunk = _chunk_projection(
z[..., :, j_chunk_start:j_chunk_end, :],
mask[..., :, j_chunk_start:j_chunk_end, :],
a=False,
)
b_chunk = permute_final_dims(b_chunk, (2, 0, 1))
x_chunk = torch.matmul(a_chunk, b_chunk)
del b_chunk
x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
x_chunk = self.layer_norm_out(x_chunk)
x_chunk = self.linear_z(x_chunk)
x_chunk *= torch.sigmoid(
self.linear_g(
z[..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :]
)
)
new_z[
..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :
] = x_chunk
del x_chunk
del a_chunk
return new_z
def forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
block_size=None,
) -> torch.Tensor:
mask = mask.unsqueeze(-1)
if not self._alphafold_original_mode:
            # scale the mask by 1/sqrt(N) (the contracted dimension) for numerical stability
mask = mask * (mask.shape[-2] ** -0.5)
z = self.layer_norm_in(z)
if not self.training and block_size is not None:
return self._chunk_2d(z, mask, block_size=block_size)
g = nn.functional.linear(z, self.linear_g.weight)
if self.training:
ab = self.linear_ab_p(z) * mask * torch.sigmoid(self.linear_ab_g(z))
else:
ab = self.linear_ab_p(z)
ab *= mask
ab *= torch.sigmoid(self.linear_ab_g(z))
a, b = torch.chunk(ab, 2, dim=-1)
del z, ab
if self.outgoing:
a = permute_final_dims(a, (2, 0, 1))
b = b.transpose(-1, -3)
else:
b = permute_final_dims(b, (2, 0, 1))
a = a.transpose(-1, -3)
x = torch.matmul(a, b)
del a, b
x = permute_final_dims(x, (1, 2, 0))
x = self.layer_norm_out(x)
x = nn.functional.linear(x, self.linear_z.weight)
return x, g
def get_output_bias(self):
return self.linear_z.bias, self.linear_g.bias
class TriangleMultiplicationOutgoing(TriangleMultiplication):
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True)
class TriangleMultiplicationIncoming(TriangleMultiplication):
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False)
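# Illustrative usage sketch (ours): in training mode the module returns the
# un-gated update and the raw gate logits, both without their output biases
# (see get_output_bias). The recombination below is our reading of what the
# fused tri_mul_residual helper applies, shown only for clarity.
def _example_triangle_multiplication():
    tm = TriangleMultiplicationOutgoing(d_pair=32, d_hid=16)
    z = torch.randn(1, 8, 8, 32)
    mask = torch.ones(1, 8, 8)
    x, g = tm(z, mask=mask)
    z_bias, g_bias = tm.get_output_bias()
    update = (x + z_bias) * torch.sigmoid(g + g_bias)
    return z + update  # residual pair update, [1, 8, 8, 32]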
""" Scripts for MSA & template searching. """