Commit 68828c49 authored by Christina Floristean

Multimer v3 updates

parent 736f27fd
+import re
 import copy
 import importlib
 import ml_collections as mlc
@@ -155,6 +156,18 @@ def model_config(
     elif "multimer" in name:
         c.globals.is_multimer = True
         c.loss.masked_msa.num_classes = 22
+        if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
+            c.model.evoformer.num_msa = 252
+            c.model.evoformer.num_extra_msa = 1152
+            c.model.evoformer.fuse_projection_weights = False
+            c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
+            c.model.template.template_pair_stack.fuse_projection_weights = False
+        elif name == 'model_4_multimer_v3':
+            c.model.evoformer.num_extra_msa = 1152
+        elif name == 'model_5_multimer_v3':
+            c.model.evoformer.num_extra_msa = 1152
         for k,v in multimer_model_config_update.items():
             c.model[k] = v
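
For reference, a standalone sketch (illustrative only, not part of the commit) of how the re.fullmatch gate above classifies checkpoint names, assuming the AlphaFold naming convention model_<1-5>_multimer[_v2|_v3]:

import re

for name in ["model_1_multimer", "model_3_multimer_v2", "model_5_multimer_v3"]:
    # v1/v2 names match and keep the unfused projections; v3 names fall
    # through to the fused defaults in multimer_model_config_update.
    legacy = bool(re.fullmatch("^model_[1-5]_multimer(_v2)?$", name))
    print(name, "-> unfused projections:", legacy)
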
@@ -438,6 +451,7 @@ config = mlc.ConfigDict(
             "pair_transition_n": 2,
             "dropout_rate": 0.25,
             "tri_mul_first": False,
+            "fuse_projection_weights": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "tune_chunk_size": tune_chunk_size,
             "inf": 1e9,
@@ -487,6 +501,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "opm_first": False,
+            "fuse_projection_weights": False,
             "clear_cache_between_blocks": False,
             "tune_chunk_size": tune_chunk_size,
             "inf": 1e9,
@@ -510,6 +525,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "opm_first": False,
+            "fuse_projection_weights": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "clear_cache_between_blocks": False,
             "tune_chunk_size": tune_chunk_size,
@@ -671,6 +687,7 @@ multimer_model_config_update = {
             "pair_transition_n": 2,
             "dropout_rate": 0.25,
             "tri_mul_first": True,
+            "fuse_projection_weights": True,
             "blocks_per_ckpt": blocks_per_ckpt,
             "inf": 1e9,
         },
@@ -701,6 +718,7 @@ multimer_model_config_update = {
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "opm_first": True,
+            "fuse_projection_weights": True,
             "clear_cache_between_blocks": True,
             "inf": 1e9,
             "eps": eps,  # 1e-10,
@@ -723,6 +741,7 @@ multimer_model_config_update = {
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "opm_first": True,
+            "fuse_projection_weights": True,
             "blocks_per_ckpt": blocks_per_ckpt,
             "clear_cache_between_blocks": False,
             "inf": 1e9,
...
@@ -178,7 +178,7 @@ class PointProjection(nn.Module):
     def forward(self,
         activations: torch.Tensor,
         rigids: Union[Rigid, Rigid3Array],
-    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         # TODO: Needs to run in high precision during training
         points_local = self.linear(activations)
@@ -398,20 +398,14 @@ class InvariantPointAttention(nn.Module):
         a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

         # [*, N_res, N_res, H, P_q, 3]
-        if self.is_multimer:
-            pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
-            # [*, N_res, N_res, H, P_q]
-            pt_att = sum([c ** 2 for c in pt_att])
-        else:
-            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-            if (inplace_safe):
-                pt_att *= pt_att
-            else:
-                pt_att = pt_att ** 2
-            pt_att = sum(torch.unbind(pt_att, dim=-1))
+        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+
+        if (inplace_safe):
+            pt_att *= pt_att
+        else:
+            pt_att = pt_att ** 2
+        pt_att = sum(torch.unbind(pt_att, dim=-1))

         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
@@ -427,6 +421,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, N_res, H]
         pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+        # [*, N_res, N_res]
         square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
         square_mask = self.inf * (square_mask - 1)
@@ -460,51 +455,35 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)

-        if self.is_multimer:
-            # As DeepMind explains, this manual matmul ensures that the operation
-            # happens in float32.
-            # [*, N_res, H, P_v]
-            o_pt = v_pts[..., None, :, :, :] * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
-            o_pt = o_pt.sum(dim=-3)
-
-            # [*, N_res, H, P_v]
-            o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
-
-            # [*, N_res, H * P_v, 3]
-            o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
-
-            # [*, N_res, H * P_v]
-            o_pt_norm = o_pt.norm(self.eps)
-        else:
-            # [*, H, 3, N_res, P_v]
-            if (inplace_safe):
-                v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
-                o_pt = [
-                    torch.matmul(a, v.to(a.dtype))
-                    for v in torch.unbind(v_pts, dim=-3)
-                ]
-                o_pt = torch.stack(o_pt, dim=-3)
-            else:
-                o_pt = torch.sum(
-                    (
-                        a[..., None, :, :, None]
-                        * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-                    ),
-                    dim=-2,
-                )
+        # [*, H, 3, N_res, P_v]
+        if (inplace_safe):
+            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
+            o_pt = [
+                torch.matmul(a, v.to(a.dtype))
+                for v in torch.unbind(v_pts, dim=-3)
+            ]
+            o_pt = torch.stack(o_pt, dim=-3)
+        else:
+            o_pt = torch.sum(
+                (
+                    a[..., None, :, :, None]
+                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+                ),
+                dim=-2,
+            )

         # [*, N_res, H, P_v, 3]
         o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
         o_pt = r[..., None, None].invert_apply(o_pt)

         # [*, N_res, H * P_v]
         o_pt_norm = flatten_final_dims(
             torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
         )

         # [*, N_res, H * P_v, 3]
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
         o_pt = torch.unbind(o_pt, dim=-1)

         if (_offload_inference):
             z[0] = z[0].to(o_pt.device)
...
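
A quick standalone check (illustrative shapes, plain PyTorch; not from the commit) that the in-place and out-of-place paths of the now-unified squared-distance computation agree:

import torch

q_pts = torch.randn(2, 8, 4, 12, 3)  # [*, N_res, H, P_q, 3]
k_pts = torch.randn(2, 8, 4, 12, 3)

# Out-of-place path
diff = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
out_of_place = sum(torch.unbind(diff ** 2, dim=-1))

# In-place path
diff = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
diff *= diff
in_place = sum(torch.unbind(diff, dim=-1))

assert torch.allclose(out_of_place, in_place)
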
@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
 from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationOutgoing,
     TriangleMultiplicationIncoming,
+    FusedTriangleMultiplicationOutgoing,
+    FusedTriangleMultiplicationIncoming
 )
 from openfold.utils.checkpointing import checkpoint_blocks
 from openfold.utils.chunk_utils import (
@@ -155,6 +157,7 @@ class TemplatePairStackBlock(nn.Module):
         pair_transition_n: int,
         dropout_rate: float,
         tri_mul_first: bool,
+        fuse_projection_weights: bool,
         inf: float,
         **kwargs,
     ):
@@ -185,14 +188,24 @@
             inf=inf,
         )

-        self.tri_mul_out = TriangleMultiplicationOutgoing(
-            self.c_t,
-            self.c_hidden_tri_mul,
-        )
-        self.tri_mul_in = TriangleMultiplicationIncoming(
-            self.c_t,
-            self.c_hidden_tri_mul,
-        )
+        if fuse_projection_weights:
+            self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
+                self.c_t,
+                self.c_hidden_tri_mul,
+            )
+            self.tri_mul_in = FusedTriangleMultiplicationIncoming(
+                self.c_t,
+                self.c_hidden_tri_mul,
+            )
+        else:
+            self.tri_mul_out = TriangleMultiplicationOutgoing(
+                self.c_t,
+                self.c_hidden_tri_mul,
+            )
+            self.tri_mul_in = TriangleMultiplicationIncoming(
+                self.c_t,
+                self.c_hidden_tri_mul,
+            )

         self.pair_transition = PairTransition(
             self.c_t,
@@ -329,6 +342,7 @@ class TemplatePairStack(nn.Module):
         pair_transition_n,
         dropout_rate,
         tri_mul_first,
+        fuse_projection_weights,
         blocks_per_ckpt,
         tune_chunk_size: bool = False,
         inf=1e9,
@@ -366,6 +380,7 @@
                 pair_transition_n=pair_transition_n,
                 dropout_rate=dropout_rate,
                 tri_mul_first=tri_mul_first,
+                fuse_projection_weights=fuse_projection_weights,
                 inf=inf,
             )
             self.blocks.append(block)
...
@@ -15,6 +15,7 @@
 from functools import partialmethod
 from typing import Optional
+from abc import ABC, abstractmethod

 import torch
 import torch.nn as nn
@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import add, permute_final_dims


-class TriangleMultiplicativeUpdate(nn.Module):
+class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
     """
     Implements Algorithms 11 and 12.
     """
-    def __init__(self, c_z, c_hidden, _outgoing=True):
+    @abstractmethod
+    def __init__(self, c_z, c_hidden, _outgoing):
         """
         Args:
             c_z:
@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
             c:
                 Hidden channel dimension
         """
-        super(TriangleMultiplicativeUpdate, self).__init__()
+        super(BaseTriangleMultiplicativeUpdate, self).__init__()
         self.c_z = c_z
         self.c_hidden = c_hidden
         self._outgoing = _outgoing

-        self.linear_a_p = Linear(self.c_z, self.c_hidden)
-        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
-        self.linear_b_p = Linear(self.c_z, self.c_hidden)
-        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
         self.linear_g = Linear(self.c_z, self.c_z, init="gating")
         self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
         return permute_final_dims(p, (1, 2, 0))
+    @abstractmethod
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        inplace_safe: bool = False,
+        _add_with_inplace: bool = False
+    ) -> torch.Tensor:
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] input tensor
+            mask:
+                [*, N_res, N_res] input mask
+        Returns:
+            [*, N_res, N_res, C_z] output tensor
+        """
+        pass
+
+
+class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
+    """
+    Implements Algorithms 11 and 12.
+    """
+    def __init__(self, c_z, c_hidden, _outgoing=True):
+        """
+        Args:
+            c_z:
+                Input channel dimension
+            c:
+                Hidden channel dimension
+        """
+        super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
+                                                           c_hidden=c_hidden,
+                                                           _outgoing=_outgoing)
+
+        self.linear_a_p = Linear(self.c_z, self.c_hidden)
+        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
+        self.linear_b_p = Linear(self.c_z, self.c_hidden)
+        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
     def _inference_forward(self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
@@ -425,3 +463,149 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
     Implements Algorithm 12.
     """
     __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
+
+
+class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
+    """
+    Implements Algorithms 11 and 12.
+    """
+    def __init__(self, c_z, c_hidden, _outgoing=True):
+        """
+        Args:
+            c_z:
+                Input channel dimension
+            c:
+                Hidden channel dimension
+        """
+        super(FusedTriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
+                                                                c_hidden=c_hidden,
+                                                                _outgoing=_outgoing)
+
+        self.linear_ab_p = Linear(self.c_z, self.c_hidden * 2)
+        self.linear_ab_g = Linear(self.c_z, self.c_hidden * 2, init="gating")
+
+    def _inference_forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        _inplace_chunk_size: Optional[int] = None,
+        with_add: bool = True,
+    ):
+        """
+        Args:
+            z:
+                A [*, N, N, C_z] pair representation
+            mask:
+                A [*, N, N] pair mask
+            with_add:
+                If True, z is overwritten with (z + update). Otherwise, it is
+                overwritten with (update).
+        Returns:
+            A reference to the overwritten z
+        """
+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        def compute_projection_helper(pair, mask):
+            pair = self.layer_norm_in(pair)
+            p = self.linear_ab_g(pair)
+            p.sigmoid_()
+            p *= self.linear_ab_p(pair)
+            p *= mask
+            return p
+
+        def compute_projection(pair, mask):
+            p = compute_projection_helper(pair, mask)
+            a = p[..., :self.c_hidden]
+            b = p[..., self.c_hidden:]
+            return a, b
+
+        a, b = compute_projection(z, mask)
+        x = self._combine_projections(a, b, _inplace_chunk_size=_inplace_chunk_size)
+        x = self.layer_norm_out(x)
+        x = self.linear_z(x)
+        g = self.linear_g(z)
+        g.sigmoid_()
+        x *= g
+
+        if (with_add):
+            z += x
+        else:
+            z = x
+
+        return z
+
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        inplace_safe: bool = False,
+        _add_with_inplace: bool = False,
+        _inplace_chunk_size: Optional[int] = 256
+    ) -> torch.Tensor:
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] input tensor
+            mask:
+                [*, N_res, N_res] input mask
+        Returns:
+            [*, N_res, N_res, C_z] output tensor
+        """
+        if (inplace_safe):
+            x = self._inference_forward(
+                z,
+                mask,
+                _inplace_chunk_size=_inplace_chunk_size,
+                with_add=_add_with_inplace,
+            )
+            return x

+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        z = self.layer_norm_in(z)
+        ab = mask
+        ab = ab * self.sigmoid(self.linear_ab_g(z))
+        ab = ab * self.linear_ab_p(z)
+        a = ab[..., :self.c_hidden]
+        b = ab[..., self.c_hidden:]
+
+        # Prevents overflow of torch.matmul in combine projections in
+        # reduced-precision modes
+        a = a / a.std()
+        b = b / b.std()
+
+        if (is_fp16_enabled()):
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self._combine_projections(a.float(), b.float())
+        else:
+            x = self._combine_projections(a, b)
+
+        del a, b
+        x = self.layer_norm_out(x)
+        x = self.linear_z(x)
+        g = self.sigmoid(self.linear_g(z))
+        x = x * g
+
+        return x
+
+
+class FusedTriangleMultiplicationOutgoing(FusedTriangleMultiplicativeUpdate):
+    """
+    Implements Algorithm 11.
+    """
+    __init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=True)
+
+
+class FusedTriangleMultiplicationIncoming(FusedTriangleMultiplicativeUpdate):
+    """
+    Implements Algorithm 12.
+    """
+    __init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=False)
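
The idea behind fuse_projection_weights is that the unfused module's separate left/right projections are equivalent to one projection of width 2 * c_hidden whose weight matrix is the two originals stacked along the output dimension. A minimal sketch of that equivalence with plain nn.Linear stand-ins (the real modules use OpenFold's Linear with custom initializers; the names here are hypothetical):

import torch
import torch.nn as nn

c_z, c_hidden = 8, 11
linear_a_p = nn.Linear(c_z, c_hidden)
linear_b_p = nn.Linear(c_z, c_hidden)

# Build the fused projection by stacking weights and biases.
linear_ab_p = nn.Linear(c_z, 2 * c_hidden)
with torch.no_grad():
    linear_ab_p.weight.copy_(torch.cat([linear_a_p.weight, linear_b_p.weight], dim=0))
    linear_ab_p.bias.copy_(torch.cat([linear_a_p.bias, linear_b_p.bias], dim=0))

z = torch.randn(4, c_z)
ab = linear_ab_p(z)
assert torch.allclose(ab[..., :c_hidden], linear_a_p(z), atol=1e-6)
assert torch.allclose(ab[..., c_hidden:], linear_b_p(z), atol=1e-6)
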
@@ -189,7 +189,7 @@ def torsion_angles_to_frames(
     rrgdf: torch.Tensor,
 ):
-    rigid_type = Rigid if isinstance(r, Rigid) else rigid_matrix_vector.Rigid3Array
+    rigid_type = type(r)

     # [*, N, 8, 4, 4]
     default_4x4 = rrgdf[aatype, ...]
@@ -217,18 +217,14 @@
     # This follows the original code rather than the supplement, which uses
     # different indices.
-    all_rots = alpha.new_zeros(default_r.shape + (3, 3))
+    all_rots = alpha.new_zeros(default_r.shape + (4, 4))
     all_rots[..., 0, 0] = 1
     all_rots[..., 1, 1] = alpha[..., 1]
     all_rots[..., 1, 2] = -alpha[..., 0]
-    all_rots[..., 2, 1:] = alpha
+    all_rots[..., 2, 1:3] = alpha

-    if isinstance(r, Rigid):
-        all_rots = Rigid(Rotation(rot_mats=all_rots), None)
-        all_frames = default_r.compose(all_rots)
-    else:
-        all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
-        all_frames = default_r.compose_rotation(all_rots)
+    all_rots = rigid_type.from_tensor_4x4(all_rots)
+    all_frames = default_r.compose(all_rots)

     chi2_frame_to_frame = all_frames[..., 5]
     chi3_frame_to_frame = all_frames[..., 6]
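
A standalone sanity check (illustrative, not from the commit) that the layout written above yields a valid rotation about the x-axis when alpha holds (sin θ, cos θ); the homogeneous corner of the 4x4 is assumed to be handled by from_tensor_4x4:

import torch

theta = torch.tensor(0.3)
alpha = torch.stack([torch.sin(theta), torch.cos(theta)])  # (sin, cos)

rot = torch.zeros(4, 4)
rot[0, 0] = 1
rot[1, 1] = alpha[1]
rot[1, 2] = -alpha[0]
rot[2, 1:3] = alpha

R = rot[:3, :3]
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)                   # orthogonal
assert torch.isclose(torch.linalg.det(R), torch.tensor(1.0), atol=1e-6)  # proper rotation
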
@@ -283,16 +279,11 @@ def frames_and_literature_positions_to_atom14_pos(
     )

     # [*, N, 14]
-    atom_mask = atom_mask[aatype, ...]
-
-    if isinstance(r, Rigid):
-        atom_mask = atom_mask.unsqueeze(-1)
+    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

     # [*, N, 14, 3]
     lit_positions = lit_positions[aatype, ...]
     pred_positions = t_atoms_to_global.apply(lit_positions)
     pred_positions = pred_positions * atom_mask

-    if isinstance(pred_positions, vector.Vec3Array):
-        return pred_positions.to_tensor()
-
     return pred_positions
@@ -67,14 +67,17 @@ class Rigid3Array:
         """Apply Rigid3Array transform to point."""
         return self.rotation.apply_to_point(point) + self.translation

-    def apply(self, point: torch.Tensor) -> vector.Vec3Array:
-        return self.apply_to_point(vector.Vec3Array.from_array(point))
+    def apply(self, point: torch.Tensor) -> torch.Tensor:
+        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

     def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
         """Apply inverse Rigid3Array transform to point."""
         new_point = point - self.translation
         return self.rotation.apply_inverse_to_point(new_point)

+    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
+        return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
+
     def compose_rotation(self, other_rotation):
         rot = self.rotation @ other_rotation
         return Rigid3Array(rot, self.translation.clone())
...
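
The new apply/invert_apply pair gives a tensor-in, tensor-out round trip. In plain-torch terms (a sketch of the underlying math, not the class's actual code path): apply computes R x + t, and invert_apply computes Rᵀ(x' − t):

import torch

R = torch.linalg.qr(torch.randn(3, 3)).Q  # orthogonal 3x3 stand-in for the rotation
t = torch.randn(3)
x = torch.randn(5, 3)

applied = x @ R.T + t            # what apply does
recovered = (applied - t) @ R    # what invert_apply does
assert torch.allclose(recovered, x, atol=1e-5)
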
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from enum import Enum
 from dataclasses import dataclass
 from functools import partial
@@ -191,31 +192,47 @@ def generate_translation_dict(model, version, is_multimer=False):
         "attention": AttentionGatedParams(tri_att.mha),
     }
-    TriMulOutParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_a_p),
-        "right_projection": LinearParams(tri_mul.linear_b_p),
-        "left_gate": LinearParams(tri_mul.linear_a_g),
-        "right_gate": LinearParams(tri_mul.linear_b_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
-
-    # see commit b88f8da on the Alphafold repo
-    # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
-    # iterations of triangle multiplication, which is confusing and not
-    # reproduced in our implementation.
-    TriMulInParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_b_p),
-        "right_projection": LinearParams(tri_mul.linear_a_p),
-        "left_gate": LinearParams(tri_mul.linear_b_g),
-        "right_gate": LinearParams(tri_mul.linear_a_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
+    def TriMulOutParams(tri_mul, outgoing=True):
+        if re.fullmatch("^model_[1-5]_multimer_v3$", version):
+            d = {
+                "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "projection": LinearParams(tri_mul.linear_ab_p),
+                "gate": LinearParams(tri_mul.linear_ab_g),
+                "center_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+        else:
+            # see commit b88f8da on the Alphafold repo
+            # Alphafold swaps the pseudocode's a and b between the incoming/outgoing
+            # iterations of triangle multiplication, which is confusing and not
+            # reproduced in our implementation.
+            if outgoing:
+                left_projection = LinearParams(tri_mul.linear_a_p)
+                right_projection = LinearParams(tri_mul.linear_b_p)
+                left_gate = LinearParams(tri_mul.linear_a_g)
+                right_gate = LinearParams(tri_mul.linear_b_g)
+            else:
+                left_projection = LinearParams(tri_mul.linear_b_p)
+                right_projection = LinearParams(tri_mul.linear_a_p)
+                left_gate = LinearParams(tri_mul.linear_b_g)
+                right_gate = LinearParams(tri_mul.linear_a_g)
+
+            d = {
+                "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "left_projection": left_projection,
+                "right_projection": right_projection,
+                "left_gate": left_gate,
+                "right_gate": right_gate,
+                "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+
+        d.update({
+            "output_projection": LinearParams(tri_mul.linear_z),
+            "gating_linear": LinearParams(tri_mul.linear_g),
+        })
+
+        return d
+
+    TriMulInParams = partial(TriMulOutParams, outgoing=False)
     PairTransitionParams = lambda pt: {
         "input_layer_norm": LayerNormParams(pt.layer_norm),
...
@@ -56,16 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
 echo "Downloading PDB mmCIF files..."
 bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"

-echo "Downloading Uniclust30..."
-bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
+echo "Downloading Uniref30..."
+bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"

 echo "Downloading Uniref90..."
 bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"

-echo "Downloading PDB SeqRes..."
-bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
-
 echo "Downloading UniProt..."
-bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
+bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading PDB SeqRes..."
+bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"

 echo "All data downloaded."
@@ -31,7 +31,7 @@ fi
 DOWNLOAD_DIR="$1"
 ROOT_DIR="${DOWNLOAD_DIR}/params"
-SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
+SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
 BASENAME=$(basename "${SOURCE_URL}")

 mkdir --parents "${ROOT_DIR}"
...
@@ -32,8 +32,8 @@ fi
 DOWNLOAD_DIR="$1"
 ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
 # Mirror of:
-# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
-SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
+# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/mgy_clusters.fa.gz
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz"
 BASENAME=$(basename "${SOURCE_URL}")

 mkdir --parents "${ROOT_DIR}"
...
@@ -36,3 +36,7 @@ BASENAME=$(basename "${SOURCE_URL}")
 mkdir --parents "${ROOT_DIR}"
 aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+
+# Keep only protein sequences.
+grep --after-context=1 --no-group-separator '>.* mol:protein' "${ROOT_DIR}/pdb_seqres.txt" > "${ROOT_DIR}/pdb_seqres_filtered.txt"
+mv "${ROOT_DIR}/pdb_seqres_filtered.txt" "${ROOT_DIR}/pdb_seqres.txt"
@@ -30,8 +30,10 @@ if ! command -v aria2c &> /dev/null ; then
 fi

 DOWNLOAD_DIR="$1"
-ROOT_DIR="${DOWNLOAD_DIR}"
-SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
+ROOT_DIR="${DOWNLOAD_DIR}/uniref30"
+# Mirror of:
+# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
 BASENAME=$(basename "${SOURCE_URL}")

 mkdir --parents "${ROOT_DIR}"
...
@@ -2,7 +2,7 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
-        "model": "model_1_multimer_v2",  # monomer: model_1_ptm, multimer: model_1_multimer_v2
+        "model": "model_1_multimer_v3",  # monomer: model_1_ptm, multimer: model_1_multimer_v3
         "is_multimer": True,  # monomer: False, multimer: True
         "chunk_size": 4,
         "batch_size": 2,
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 import torch
 import numpy as np
 import unittest
@@ -49,6 +50,7 @@ class TestEvoformerStack(unittest.TestCase):
         msa_dropout = 0.15
         pair_stack_dropout = 0.25
         opm_first = consts.is_multimer
+        fuse_projection_weights = bool(re.fullmatch("^model_[1-5]_multimer_v3$", consts.model))
         inf = 1e9
         eps = 1e-10
@@ -67,6 +69,7 @@ class TestEvoformerStack(unittest.TestCase):
             msa_dropout,
             pair_stack_dropout,
             opm_first,
+            fuse_projection_weights,
             blocks_per_ckpt=None,
             inf=inf,
             eps=eps,
@@ -177,6 +180,7 @@ class TestExtraMSAStack(unittest.TestCase):
         msa_dropout = 0.15
         pair_stack_dropout = 0.25
         opm_first = consts.is_multimer
+        fuse_projection_weights = bool(re.fullmatch("^model_[1-5]_multimer_v3$", consts.model))
         inf = 1e9
         eps = 1e-10
@@ -194,6 +198,7 @@
             msa_dropout,
             pair_stack_dropout,
             opm_first,
+            fuse_projection_weights,
             ckpt=False,
             inf=inf,
             eps=eps,
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 import torch
 import numpy as np
 import unittest
@@ -78,6 +79,7 @@ class TestTemplatePairStack(unittest.TestCase):
         n_templ = consts.n_templ
         n_res = consts.n_res
         tri_mul_first = consts.is_multimer
+        fuse_projection_weights = bool(re.fullmatch("^model_[1-5]_multimer_v3$", consts.model))
         blocks_per_ckpt = None
         chunk_size = 4
         inf = 1e7
@@ -92,6 +94,7 @@
             pair_transition_n=pt_inner_dim,
             dropout_rate=dropout,
             tri_mul_first=tri_mul_first,
+            fuse_projection_weights=fuse_projection_weights,
             blocks_per_ckpt=None,
             inf=inf,
             eps=eps,
...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import torch
+import re
 import numpy as np
 import unittest
 from openfold.model.triangular_multiplicative_update import *
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         c_z = consts.c_z
         c = 11

-        tm = TriangleMultiplicationOutgoing(
-            c_z,
-            c,
-        )
+        if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
+            tm = FusedTriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )
+        else:
+            tm = TriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )

         n_res = consts.c_z
         batch_size = consts.batch_size
@@ -62,7 +69,7 @@
                 config.model.global_config,
                 name=name,
             )
-            act = tri_mul(act=pair_act, mask=pair_mask)
+            act = tri_mul(pair_act, pair_mask)
             return act

         f = hk.transform(run_tri_mul)
@@ -89,6 +96,7 @@
             if incoming
             else model.evoformer.blocks[0].pair_stack.tri_mul_out
         )
+
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
             mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
...
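
A minimal smoke test for the new fused module (a sketch; it assumes the constructor and forward signatures shown in the diff, and that the layer norms and gating linears come from the shared base class):

import torch
from openfold.model.triangular_multiplicative_update import (
    FusedTriangleMultiplicationOutgoing,
)

c_z, c_hidden = 8, 11
tm = FusedTriangleMultiplicationOutgoing(c_z, c_hidden)
z = torch.randn(1, 16, 16, c_z)
mask = torch.ones(1, 16, 16)
out = tm(z, mask=mask)
assert out.shape == z.shape
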