Commit 68828c49 authored by Christina Floristean

Multimer v3 updates

parent 736f27fd
import re
import copy
import importlib
import ml_collections as mlc
......@@ -155,6 +156,18 @@ def model_config(
elif "multimer" in name:
c.globals.is_multimer = True
c.loss.masked_msa.num_classes = 22
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
c.model.evoformer.num_msa = 252
c.model.evoformer.num_extra_msa = 1152
c.model.evoformer.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
elif name == 'model_5_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
for k, v in multimer_model_config_update.items():
c.model[k] = v
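Taken together, these branches encode the checkpoint-family logic: v1/v2 multimer weights keep the unfused projections and the smaller MSA stacks, every v3 checkpoint uses fused projections (set via multimer_model_config_update below), and v3 models 4 and 5 additionally shrink the extra-MSA stack. A minimal standalone sketch of that selection logic, with a hypothetical helper name and return dict:

import re

def multimer_overrides(name: str) -> dict:
    # v1/v2 checkpoints: unfused projection weights, smaller MSA stacks
    if re.fullmatch(r"model_[1-5]_multimer(_v2)?", name):
        return {
            "fuse_projection_weights": False,
            "num_msa": 252,
            "num_extra_msa": 1152,
        }
    # v3 checkpoints keep fused projections; models 4/5 use fewer extra MSAs
    overrides = {"fuse_projection_weights": True}
    if name in ("model_4_multimer_v3", "model_5_multimer_v3"):
        overrides["num_extra_msa"] = 1152
    return overrides

assert multimer_overrides("model_1_multimer_v2")["fuse_projection_weights"] is False
assert multimer_overrides("model_5_multimer_v3")["num_extra_msa"] == 1152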
......@@ -438,6 +451,7 @@ config = mlc.ConfigDict(
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -487,6 +501,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -510,6 +525,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
......@@ -671,6 +687,7 @@ multimer_model_config_update = {
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
......@@ -701,6 +718,7 @@ multimer_model_config_update = {
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
......@@ -723,6 +741,7 @@ multimer_model_config_update = {
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
......
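The loop at the top of this change applies these blocks wholesale: each key of multimer_model_config_update replaces the corresponding monomer sub-config on c.model. A toy illustration of that pattern with made-up keys:

import ml_collections as mlc

# each entry of the update dict replaces a whole sub-config on c.model
c = mlc.ConfigDict({"model": {"evoformer": {"fuse_projection_weights": False}}})
update = {"evoformer": {"fuse_projection_weights": True}}
for k, v in update.items():
    c.model[k] = v
assert c.model.evoformer.fuse_projection_weights is True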
This diff is collapsed.
......@@ -178,7 +178,7 @@ class PointProjection(nn.Module):
def forward(self,
activations: torch.Tensor,
rigids: Union[Rigid, Rigid3Array],
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
......@@ -398,20 +398,14 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
if self.is_multimer:
    pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
    # [*, N_res, N_res, H, P_q]
    pt_att = sum([c ** 2 for c in pt_att])
else:
    pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
    if (inplace_safe):
        pt_att *= pt_att
    else:
        pt_att = pt_att ** 2
    pt_att = sum(torch.unbind(pt_att, dim=-1))
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
if (inplace_safe):
    pt_att *= pt_att
else:
    pt_att = pt_att ** 2
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
......@@ -427,6 +421,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
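The surviving line computes all pairwise query/key point differences in a single broadcast. A small shape check of the unsqueeze pattern, with arbitrary toy dimensions:

import torch

N_res, H, P_q = 5, 2, 3
q_pts = torch.rand(N_res, H, P_q, 3)
k_pts = torch.rand(N_res, H, P_q, 3)

# [N_res, 1, H, P_q, 3] - [1, N_res, H, P_q, 3]
# -> pairwise differences [N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
assert pt_att.shape == (N_res, N_res, H, P_q, 3)

# squared distance, summed over x/y/z: [N_res, N_res, H, P_q]
sq_dist = sum(torch.unbind(pt_att ** 2, dim=-1))
assert sq_dist.shape == (N_res, N_res, H, P_q)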
......@@ -460,51 +455,35 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
if self.is_multimer:
    # As DeepMind explains, this manual matmul ensures that the operation
    # happens in float32.
    # [*, N_res, H, P_v]
    o_pt = v_pts[..., None, :, :, :] * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
    o_pt = o_pt.sum(dim=-3)
    # [*, N_res, H, P_v]
    o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
    # [*, N_res, H * P_v, 3]
    o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
    # [*, N_res, H * P_v]
    o_pt_norm = o_pt.norm(self.eps)
else:
    # [*, H, 3, N_res, P_v]
    if (inplace_safe):
        v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
        o_pt = [
            torch.matmul(a, v.to(a.dtype))
            for v in torch.unbind(v_pts, dim=-3)
        ]
        o_pt = torch.stack(o_pt, dim=-3)
    else:
        o_pt = torch.sum(
            (
                a[..., None, :, :, None]
                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
            ),
            dim=-2,
        )
    # [*, N_res, H, P_v, 3]
    o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
    o_pt = r[..., None, None].invert_apply(o_pt)
    # [*, N_res, H * P_v]
    o_pt_norm = flatten_final_dims(
        torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
    )
    # [*, N_res, H * P_v, 3]
    o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
    o_pt = torch.unbind(o_pt, dim=-1)
# [*, H, 3, N_res, P_v]
if (inplace_safe):
    v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
    o_pt = [
        torch.matmul(a, v.to(a.dtype))
        for v in torch.unbind(v_pts, dim=-3)
    ]
    o_pt = torch.stack(o_pt, dim=-3)
else:
    o_pt = torch.sum(
        (
            a[..., None, :, :, None]
            * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
        ),
        dim=-2,
    )
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
    torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
......
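The matmul path and the broadcast-and-sum path compute the same attention-weighted aggregation of value points; the former just performs one [N, N] x [N, P_v] product per head and coordinate axis, which is friendlier to in-place memory use. A toy equivalence check under made-up dimensions:

import torch

H, N, P_v = 2, 3, 4
a = torch.rand(H, N, N)        # per-head attention weights
v = torch.rand(H, 3, N, P_v)   # value points, permuted to [H, 3, N, P_v]

# matmul per coordinate axis, then restack
o1 = torch.stack(
    [torch.matmul(a, x) for x in torch.unbind(v, dim=-3)], dim=-3
)
# broadcast the weights against all points, then reduce over keys
o2 = (a[..., None, :, :, None] * v[..., None, :, :]).sum(dim=-2)

assert torch.allclose(o1, o2, atol=1e-6)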
......@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing,
FusedTriangleMultiplicationIncoming
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
......@@ -155,6 +157,7 @@ class TemplatePairStackBlock(nn.Module):
pair_transition_n: int,
dropout_rate: float,
tri_mul_first: bool,
fuse_projection_weights: bool,
inf: float,
**kwargs,
):
......@@ -185,14 +188,24 @@ class TemplatePairStackBlock(nn.Module):
inf=inf,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
if fuse_projection_weights:
self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = FusedTriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
self.pair_transition = PairTransition(
self.c_t,
......@@ -329,6 +342,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n,
dropout_rate,
tri_mul_first,
fuse_projection_weights,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
......@@ -366,6 +380,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
)
self.blocks.append(block)
......
......@@ -15,6 +15,7 @@
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
......@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
@abstractmethod
def __init__(self, c_z, c_hidden, _outgoing):
"""
Args:
c_z:
......@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__()
super(BaseTriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
......@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
return permute_final_dims(p, (1, 2, 0))
@abstractmethod
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
pass
class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......@@ -425,3 +463,149 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
Implements Algorithm 12.
"""
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
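The direction-specific classes differ from their parent only in the pinned _outgoing flag, captured at class-definition time with functools.partialmethod. A minimal standalone illustration of the idiom, with toy class names:

from functools import partialmethod

class Update:
    def __init__(self, c_z, c_hidden, _outgoing):
        self._outgoing = _outgoing

class Outgoing(Update):
    # the subclass pins _outgoing, so callers pass only c_z and c_hidden
    __init__ = partialmethod(Update.__init__, _outgoing=True)

class Incoming(Update):
    __init__ = partialmethod(Update.__init__, _outgoing=False)

assert Outgoing(8, 4)._outgoing is True
assert Incoming(8, 4)._outgoing is False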
class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(FusedTriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_ab_p = Linear(self.c_z, self.c_hidden * 2)
self.linear_ab_g = Linear(self.c_z, self.c_hidden * 2, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
_inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask):
pair = self.layer_norm_in(pair)
p = self.linear_ab_g(pair)
p.sigmoid_()
p *= self.linear_ab_p(pair)
p *= mask
return p
def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask)
a = p[..., :self.c_hidden]
b = p[..., self.c_hidden:]
return a, b
a, b = compute_projection(z, mask)
x = self._combine_projections(a, b, _inplace_chunk_size=_inplace_chunk_size)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g.sigmoid_()
x *= g
if (with_add):
z += x
else:
z = x
return z
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if (inplace_safe):
x = self._inference_forward(
z,
mask,
_inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
ab = mask
ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z)
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a = a / a.std()
b = b / b.std()
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
x = x * g
return x
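Dividing a and b by their standard deviations keeps the projected activations near unit scale, so the matmul's accumulated sums stay inside float16's representable range (max ~65504). A toy sketch of why the rescale matters:

import torch

# float16 saturates to inf beyond ~65504
assert torch.isinf(torch.tensor(70000.0).to(torch.float16))

# unit-variance inputs keep products and sums comfortably in range
a = torch.rand(16, 32) * 1000.0
a_scaled = a / a.std()
assert torch.isclose(a_scaled.std(), torch.tensor(1.0), atol=1e-3)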
class FusedTriangleMultiplicationOutgoing(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=True)
class FusedTriangleMultiplicationIncoming(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=False)
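The fused classes store one projection (and one gate) of width 2 * c_hidden where the unfused classes store two of width c_hidden each, matching how the v3 AlphaFold checkpoints lay out these weights. A hedged sketch, with plain nn.Linear standing in for OpenFold's Linear, showing that concatenating separate a/b weights reproduces the fused projection exactly:

import torch
import torch.nn as nn

c_z, c_hidden = 8, 4
z = torch.rand(2, 5, 5, c_z)

linear_a_p = nn.Linear(c_z, c_hidden)      # unfused left projection
linear_b_p = nn.Linear(c_z, c_hidden)      # unfused right projection

# one fused linear of width 2 * c_hidden replaces the two above
linear_ab_p = nn.Linear(c_z, c_hidden * 2)
with torch.no_grad():
    linear_ab_p.weight.copy_(torch.cat([linear_a_p.weight, linear_b_p.weight], dim=0))
    linear_ab_p.bias.copy_(torch.cat([linear_a_p.bias, linear_b_p.bias], dim=0))

ab = linear_ab_p(z)
a, b = ab[..., :c_hidden], ab[..., c_hidden:]
assert torch.allclose(a, linear_a_p(z), atol=1e-6)
assert torch.allclose(b, linear_b_p(z), atol=1e-6)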
......@@ -189,7 +189,7 @@ def torsion_angles_to_frames(
rrgdf: torch.Tensor,
):
rigid_type = Rigid if isinstance(r, Rigid) else rigid_matrix_vector.Rigid3Array
rigid_type = type(r)
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
......@@ -217,18 +217,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.shape + (3, 3))
all_rots = alpha.new_zeros(default_r.shape + (4, 4))
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots[..., 2, 1:3] = alpha
if isinstance(r, Rigid):
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
else:
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
all_rots = rigid_type.from_tensor_4x4(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
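The new path writes each torsion rotation, a rotation about the x-axis built from a (sin, cos) pair, into a 4x4 tensor so that both Rigid and Rigid3Array can ingest it through from_tensor_4x4 (which reads the top-left 3x3 as the rotation and the last column as the translation). A toy check of that layout with an arbitrary angle:

import math
import torch

angle = 0.3
alpha = torch.tensor([math.sin(angle), math.cos(angle)])  # (sin, cos)

rot = torch.zeros(4, 4)
rot[0, 0] = 1
rot[1, 1] = alpha[1]
rot[1, 2] = -alpha[0]
rot[2, 1:3] = alpha

# the embedded 3x3 rotates about x: (0, 1, 0) -> (0, cos, sin)
point = torch.tensor([0.0, 1.0, 0.0])
rotated = rot[:3, :3] @ point
expected = torch.tensor([0.0, math.cos(angle), math.sin(angle)])
assert torch.allclose(rotated, expected, atol=1e-6)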
......@@ -283,16 +279,11 @@ def frames_and_literature_positions_to_atom14_pos(
)
# [*, N, 14]
atom_mask = atom_mask[aatype, ...]
if isinstance(r, Rigid):
atom_mask = atom_mask.unsqueeze(-1)
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
if isinstance(pred_positions, vector.Vec3Array):
return pred_positions.to_tensor()
return pred_positions
......@@ -67,14 +67,17 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> vector.Vec3Array:
return self.apply_to_point(vector.Vec3Array.from_array(point))
def apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
return Rigid3Array(rot, self.translation.clone())
......
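apply and invert_apply now take and return plain tensors, converting through Vec3Array internally. A pure-torch sketch of the intended round trip, y = R x + t followed by x = R^T (y - t), with a toy rotation and translation:

import torch

R = torch.tensor([[0., -1., 0.],
                  [1., 0., 0.],
                  [0., 0., 1.]])   # rotation about the z-axis
t = torch.tensor([1., 2., 3.])     # translation

x = torch.rand(5, 3)               # a batch of points
y = x @ R.T + t                    # what apply computes
x_back = (y - t) @ R               # what invert_apply computes
assert torch.allclose(x, x_back, atol=1e-6)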
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -191,31 +192,47 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention": AttentionGatedParams(tri_att.mha),
}
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p),
"right_projection": LinearParams(tri_mul.linear_b_p),
"left_gate": LinearParams(tri_mul.linear_a_g),
"right_gate": LinearParams(tri_mul.linear_b_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
def TriMulOutParams(tri_mul, outgoing=True):
if re.fullmatch("^model_[1-5]_multimer_v3$", version):
d = {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_ab_p),
"gate": LinearParams(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
}
else:
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
if outgoing:
left_projection = LinearParams(tri_mul.linear_a_p)
right_projection = LinearParams(tri_mul.linear_b_p)
left_gate = LinearParams(tri_mul.linear_a_g)
right_gate = LinearParams(tri_mul.linear_b_g)
else:
left_projection = LinearParams(tri_mul.linear_b_p)
right_projection = LinearParams(tri_mul.linear_a_p)
left_gate = LinearParams(tri_mul.linear_b_g)
right_gate = LinearParams(tri_mul.linear_a_g)
d = {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": left_projection,
"right_projection": right_projection,
"left_gate": left_gate,
"right_gate": right_gate,
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
}
d.update({
    "output_projection": LinearParams(tri_mul.linear_z),
    "gating_linear": LinearParams(tri_mul.linear_g),
})
return d
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
    "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
    "left_projection": LinearParams(tri_mul.linear_b_p),
    "right_projection": LinearParams(tri_mul.linear_a_p),
    "left_gate": LinearParams(tri_mul.linear_b_g),
    "right_gate": LinearParams(tri_mul.linear_a_g),
    "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
    "output_projection": LinearParams(tri_mul.linear_z),
    "gating_linear": LinearParams(tri_mul.linear_g),
}
TriMulInParams = partial(TriMulOutParams, outgoing=False)
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
......
......@@ -56,16 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniclust30..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref30..."
bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."
......@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -32,8 +32,8 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -36,3 +36,7 @@ BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
# Keep only protein sequences.
grep --after-context=1 --no-group-separator '>.* mol:protein' "${ROOT_DIR}/pdb_seqres.txt" > "${ROOT_DIR}/pdb_seqres_filtered.txt"
mv "${ROOT_DIR}/pdb_seqres_filtered.txt" "${ROOT_DIR}/pdb_seqres.txt"
......@@ -30,8 +30,10 @@ if ! command -v aria2c &> /dev/null ; then
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}"
SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
ROOT_DIR="${DOWNLOAD_DIR}/uniref30"
# Mirror of:
# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
......@@ -2,7 +2,7 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v2", # monomer:model_1_ptm, multimer: model_1_multimer_v2
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -49,6 +50,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -67,6 +69,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -177,6 +180,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -194,6 +198,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
ckpt=False,
inf=inf,
eps=eps,
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -78,6 +79,7 @@ class TestTemplatePairStack(unittest.TestCase):
n_templ = consts.n_templ
n_res = consts.n_res
tri_mul_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
blocks_per_ckpt = None
chunk_size = 4
inf = 1e7
......@@ -92,6 +94,7 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import torch
import re
import numpy as np
import unittest
from openfold.model.triangular_multiplicative_update import *
......@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z = consts.c_z
c = 11
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
tm = FusedTriangleMultiplicationOutgoing(
c_z,
c,
)
else:
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
n_res = consts.c_z
batch_size = consts.batch_size
......@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
act = tri_mul(pair_act, pair_mask)
return act
f = hk.transform(run_tri_mul)
......@@ -89,6 +96,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if incoming
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
......