Commit 68828c49 authored by Christina Floristean

Multimer v3 updates

parent 736f27fd
import re
import copy
import importlib
import ml_collections as mlc
......@@ -155,6 +156,18 @@ def model_config(
elif "multimer" in name:
c.globals.is_multimer = True
c.loss.masked_msa.num_classes = 22
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
c.model.evoformer.num_msa = 252
c.model.evoformer.num_extra_msa = 1152
c.model.evoformer.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
elif name == 'model_5_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
for k, v in multimer_model_config_update.items():
c.model[k] = v
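
A minimal sketch of how these presets resolve (illustrative only; it assumes nothing beyond model_config and the fields set above):

from openfold.config import model_config

# Per the branch above, the v1/v2 multimer presets explicitly set
# fuse_projection_weights = False, while the *_multimer_v3 presets fall
# through to the multimer_model_config_update defaults below, which set it
# to True.
for preset in ("model_1_multimer_v2", "model_1_multimer_v3"):
    c = model_config(preset)
    print(preset, c.model.evoformer.fuse_projection_weights)
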
......@@ -438,6 +451,7 @@ config = mlc.ConfigDict(
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -487,6 +501,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -510,6 +525,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
......@@ -671,6 +687,7 @@ multimer_model_config_update = {
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
......@@ -701,6 +718,7 @@ multimer_model_config_update = {
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
......@@ -723,6 +741,7 @@ multimer_model_config_update = {
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
......
......@@ -18,6 +18,7 @@ import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial
from abc import ABC, abstractmethod
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
......@@ -36,6 +37,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
......@@ -127,19 +130,30 @@ class PairStack(nn.Module):
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
fuse_projection_weights: bool,
inf: float,
eps: float
):
super(PairStack, self).__init__()
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
if fuse_projection_weights:
self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = FusedTriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
self.tri_att_start = TriangleAttention(
c_z,
......@@ -162,15 +176,14 @@ class PairStack(nn.Module):
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
def forward(self,
input_tensors: Sequence[torch.Tensor],
z: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_offload_inference: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
......@@ -179,8 +192,6 @@ class PairStack(nn.Module):
if (_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
......@@ -223,8 +234,7 @@ class PairStack(nn.Module):
z = z.transpose(-2, -3)
if (inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = z.contiguous()
z = add(z,
self.ps_dropout_row_layer(
......@@ -242,8 +252,7 @@ class PairStack(nn.Module):
z = z.transpose(-2, -3)
if (inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = z.contiguous()
z = add(z,
self.pair_transition(
......@@ -252,20 +261,10 @@ class PairStack(nn.Module):
inplace=inplace_safe,
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
return z
class EvoformerBlock(nn.Module):
class MSABlock(nn.Module, ABC):
@abstractmethod
def __init__(self,
c_m: int,
c_z: int,
......@@ -279,10 +278,11 @@ class EvoformerBlock(nn.Module):
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
super(MSABlock, self).__init__()
self.opm_first = opm_first
......@@ -294,13 +294,6 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.msa_transition = MSATransition(
......@@ -321,10 +314,103 @@ class EvoformerBlock(nn.Module):
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps
)
def _compute_opm(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
_offload_inference: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
m, z = input_tensors
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
return m, z
@abstractmethod
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pass
class EvoformerBlock(MSABlock):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__(c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
......@@ -354,26 +440,13 @@ class EvoformerBlock(nn.Module):
m, z = input_tensors
if self.opm_first:
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
del m, z
z = add(z, opm, inplace=inplace_safe)
del opm
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
m = add(m,
self.msa_dropout_layer(
......@@ -417,21 +490,18 @@ class EvoformerBlock(nn.Module):
)
if not self.opm_first:
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (not inplace_safe):
input_tensors = [m, z]
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
del m, z
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
z = add(z, opm, inplace=inplace_safe)
del opm
elif (_offload_inference and inplace_safe):
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
......@@ -445,21 +515,31 @@ class EvoformerBlock(nn.Module):
del m, z
m, z = self.pair_stack(
input_tensors,
z = self.pair_stack(
z=input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size
)
m = input_tensors[0]
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
class ExtraMSABlock(nn.Module):
class ExtraMSABlock(MSABlock):
"""
Almost identical to the standard EvoformerBlock, except that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
......@@ -479,23 +559,29 @@ class ExtraMSABlock(nn.Module):
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
super(ExtraMSABlock, self).__init__(c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps)
self.opm_first = opm_first
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
......@@ -504,30 +590,6 @@ class ExtraMSABlock(nn.Module):
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
self.pair_stack = PairStack(
c_z=c_z,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
......@@ -540,7 +602,7 @@ class ExtraMSABlock(nn.Module):
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
......@@ -553,28 +615,15 @@ class ExtraMSABlock(nn.Module):
m, z = input_tensors
if self.opm_first:
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
del m, z
z = add(z, opm, inplace=inplace_safe)
del opm
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
m = add(m,
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m.clone() if torch.is_grad_enabled() else m,
......@@ -625,37 +674,52 @@ class ExtraMSABlock(nn.Module):
)
if not self.opm_first:
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (not inplace_safe):
input_tensors = [m, z]
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
del m, z
z = add(z, opm, inplace=inplace_safe)
del opm
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
device = input_tensors[0].device
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
if (not inplace_safe):
input_tensors = [m, z]
del m, z
m, z = self.pair_stack(
input_tensors,
z = self.pair_stack(
input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size
)
m = input_tensors[0]
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
if (torch.is_grad_enabled() and self.ckpt):
......@@ -690,6 +754,7 @@ class EvoformerStack(nn.Module):
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
blocks_per_ckpt: int,
inf: float,
eps: float,
......@@ -755,6 +820,7 @@ class EvoformerStack(nn.Module):
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps,
)
......@@ -940,6 +1006,7 @@ class ExtraMSAStack(nn.Module):
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
ckpt: bool,
......@@ -966,6 +1033,7 @@ class ExtraMSAStack(nn.Module):
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps,
ckpt=False,
......
......@@ -178,7 +178,7 @@ class PointProjection(nn.Module):
def forward(self,
activations: torch.Tensor,
rigids: Union[Rigid, Rigid3Array],
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
......@@ -398,20 +398,14 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
if self.is_multimer:
pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
# [*, N_res, N_res, H, P_q]
pt_att = sum([c ** 2 for c in pt_att])
if (inplace_safe):
pt_att *= pt_att
else:
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
if (inplace_safe):
pt_att *= pt_att
else:
pt_att = pt_att ** 2
pt_att = sum(torch.unbind(pt_att, dim=-1))
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
......@@ -427,6 +421,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
......@@ -460,51 +455,35 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
if self.is_multimer:
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, N_res, H, P_v]
o_pt = v_pts[..., None, :, :, :] * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
# [*, H, 3, N_res, P_v]
if (inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
for v in torch.unbind(v_pts, dim=-3)
]
o_pt = torch.stack(o_pt, dim=-3)
else:
# [*, H, 3, N_res, P_v]
if (inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
for v in torch.unbind(v_pts, dim=-3)
]
o_pt = torch.stack(o_pt, dim=-3)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
......
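
The unified point-attention difference above broadcasts identically for the monomer and multimer paths; a quick shape check (illustrative sizes only):

import torch

# q_pts, k_pts: [*, N_res, H, P_q, 3]
q_pts = torch.randn(2, 64, 12, 4, 3)
k_pts = torch.randn(2, 64, 12, 4, 3)
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
# [*, N_res, N_res, H, P_q, 3]
assert pt_att.shape == (2, 64, 64, 12, 4, 3)
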
......@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing,
FusedTriangleMultiplicationIncoming
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
......@@ -155,6 +157,7 @@ class TemplatePairStackBlock(nn.Module):
pair_transition_n: int,
dropout_rate: float,
tri_mul_first: bool,
fuse_projection_weights: bool,
inf: float,
**kwargs,
):
......@@ -185,14 +188,24 @@ class TemplatePairStackBlock(nn.Module):
inf=inf,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
if fuse_projection_weights:
self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = FusedTriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
self.pair_transition = PairTransition(
self.c_t,
......@@ -329,6 +342,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n,
dropout_rate,
tri_mul_first,
fuse_projection_weights,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
......@@ -366,6 +380,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
)
self.blocks.append(block)
......
......@@ -15,6 +15,7 @@
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
......@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
@abstractmethod
def __init__(self, c_z, c_hidden, _outgoing):
"""
Args:
c_z:
......@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
c_hidden:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__()
super(BaseTriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
......@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
return permute_final_dims(p, (1, 2, 0))
@abstractmethod
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
pass
class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c_hidden:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......@@ -425,3 +463,149 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
Implements Algorithm 12.
"""
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c_hidden:
Hidden channel dimension
"""
super(FusedTriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_ab_p = Linear(self.c_z, self.c_hidden * 2)
self.linear_ab_g = Linear(self.c_z, self.c_hidden * 2, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
_inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask):
pair = self.layer_norm_in(pair)
p = self.linear_ab_g(pair)
p.sigmoid_()
p *= self.linear_ab_p(pair)
p *= mask
return p
def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask)
a = p[..., :self.c_hidden]
b = p[..., self.c_hidden:]
return a, b
a, b = compute_projection(z, mask)
x = self._combine_projections(a, b, _inplace_chunk_size=_inplace_chunk_size)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g.sigmoid_()
x *= g
if (with_add):
z += x
else:
z = x
return z
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if (inplace_safe):
x = self._inference_forward(
z,
mask,
_inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
ab = mask
ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z)
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a = a / a.std()
b = b / b.std()
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
x = x * g
return x
class FusedTriangleMultiplicationOutgoing(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=True)
class FusedTriangleMultiplicationIncoming(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=False)
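
A minimal usage sketch of the fused update (arbitrary sizes; only the forward shown above is exercised):

import torch
from openfold.model.triangular_multiplicative_update import (
    FusedTriangleMultiplicationOutgoing,
)

c_z, c_hidden, n_res = 64, 32, 16
tm = FusedTriangleMultiplicationOutgoing(c_z, c_hidden)
z = torch.randn(1, n_res, n_res, c_z)
mask = torch.ones(1, n_res, n_res)

# The fused variant packs the left/right projections into single
# Linear(c_z, 2 * c_hidden) layers, matching the "projection"/"gate"
# parameter layout of the AlphaFold-Multimer v3 checkpoints.
update = tm(z, mask=mask)  # [1, n_res, n_res, c_z]
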
......@@ -189,7 +189,7 @@ def torsion_angles_to_frames(
rrgdf: torch.Tensor,
):
rigid_type = Rigid if isinstance(r, Rigid) else rigid_matrix_vector.Rigid3Array
rigid_type = type(r)
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
......@@ -217,18 +217,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.shape + (3, 3))
all_rots = alpha.new_zeros(default_r.shape + (4, 4))
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots[..., 2, 1:3] = alpha
if isinstance(r, Rigid):
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
else:
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
all_rots = rigid_type.from_tensor_4x4(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
......@@ -283,16 +279,11 @@ def frames_and_literature_positions_to_atom14_pos(
)
# [*, N, 14]
atom_mask = atom_mask[aatype, ...]
if isinstance(r, Rigid):
atom_mask = atom_mask.unsqueeze(-1)
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
if isinstance(pred_positions, vector.Vec3Array):
return pred_positions.to_tensor()
return pred_positions
......@@ -67,14 +67,17 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> vector.Vec3Array:
return self.apply_to_point(vector.Vec3Array.from_array(point))
def apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
return Rigid3Array(rot, self.translation.clone())
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -191,31 +192,47 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention": AttentionGatedParams(tri_att.mha),
}
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p),
"right_projection": LinearParams(tri_mul.linear_b_p),
"left_gate": LinearParams(tri_mul.linear_a_g),
"right_gate": LinearParams(tri_mul.linear_b_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
def TriMulOutParams(tri_mul, outgoing=True):
if re.fullmatch("^model_[1-5]_multimer_v3$", version):
d = {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_ab_p),
"gate": LinearParams(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
}
else:
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
if outgoing:
left_projection = LinearParams(tri_mul.linear_a_p)
right_projection = LinearParams(tri_mul.linear_b_p)
left_gate = LinearParams(tri_mul.linear_a_g)
right_gate = LinearParams(tri_mul.linear_b_g)
else:
left_projection = LinearParams(tri_mul.linear_b_p)
right_projection = LinearParams(tri_mul.linear_a_p)
left_gate = LinearParams(tri_mul.linear_b_g)
right_gate = LinearParams(tri_mul.linear_a_g)
d = {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": left_projection,
"right_projection": right_projection,
"left_gate": left_gate,
"right_gate": right_gate,
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
}
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_b_p),
"right_projection": LinearParams(tri_mul.linear_a_p),
"left_gate": LinearParams(tri_mul.linear_b_g),
"right_gate": LinearParams(tri_mul.linear_a_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
d.update({
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
})
return d
TriMulInParams = partial(TriMulOutParams, outgoing=False)
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
......
......@@ -56,16 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniclust30..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref30..."
bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."
......@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -32,8 +32,8 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -36,3 +36,7 @@ BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
# Keep only protein sequences.
grep --after-context=1 --no-group-separator '>.* mol:protein' "${ROOT_DIR}/pdb_seqres.txt" > "${ROOT_DIR}/pdb_seqres_filtered.txt"
mv "${ROOT_DIR}/pdb_seqres_filtered.txt" "${ROOT_DIR}/pdb_seqres.txt"
......@@ -30,8 +30,10 @@ if ! command -v aria2c &> /dev/null ; then
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}"
SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
ROOT_DIR="${DOWNLOAD_DIR}/uniref30"
# Mirror of:
# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -2,7 +2,7 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v2", # monomer:model_1_ptm, multimer: model_1_multimer_v2
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -49,6 +50,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -67,6 +69,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -177,6 +180,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -194,6 +198,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
ckpt=False,
inf=inf,
eps=eps,
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -78,6 +79,7 @@ class TestTemplatePairStack(unittest.TestCase):
n_templ = consts.n_templ
n_res = consts.n_res
tri_mul_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
blocks_per_ckpt = None
chunk_size = 4
inf = 1e7
......@@ -92,6 +94,7 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import torch
import re
import numpy as np
import unittest
from openfold.model.triangular_multiplicative_update import *
......@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z = consts.c_z
c = 11
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
tm = FusedTriangleMultiplicationOutgoing(
c_z,
c,
)
else:
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
n_res = consts.c_z
batch_size = consts.batch_size
......@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
act = tri_mul(pair_act, pair_mask)
return act
f = hk.transform(run_tri_mul)
......@@ -89,6 +96,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if incoming
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
......