Commit b026de28 authored by Gustaf Ahdritz

Make chunk size more flexible, reduce verbosity

parent a59ae7c1
@@ -45,13 +45,12 @@ def model_config(name, train=False, low_prec=False):
     if train:
         c.globals.blocks_per_ckpt = 1
-        c.globals.chunk_size = None
     if low_prec:
         c.globals.eps = 1e-4
         # If we want exact numerical parity with the original, inf can't be
         # a global constant
-        set_inf(c, 1e4)
+        set_inf(c, 1e5)
     return c
@@ -225,7 +224,8 @@ config = mlc.ConfigDict(
         # Recurring FieldReferences that can be changed globally here
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
-            "chunk_size": chunk_size,
+            "train_chunk_size": None,
+            "eval_chunk_size": chunk_size,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
@@ -277,8 +277,7 @@ config = mlc.ConfigDict(
                 "pair_transition_n": 2,
                 "dropout_rate": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
-                "chunk_size": chunk_size,
-                "inf": 1e5,  # 1e9,
+                "inf": 1e9,
             },
             "template_pointwise_attention": {
                 "c_t": c_t,
@@ -287,7 +286,6 @@ config = mlc.ConfigDict(
                 # It's actually 16.
                 "c_hidden": 16,
                 "no_heads": 4,
-                "chunk_size": chunk_size,
                 "inf": 1e5,  # 1e9,
             },
             "inf": 1e5,  # 1e9,
@@ -314,8 +312,7 @@ config = mlc.ConfigDict(
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
-                "chunk_size": chunk_size,
-                "inf": 1e5,  # 1e9,
+                "inf": 1e9,
                 "eps": eps,  # 1e-10,
             },
             "enabled": True,
@@ -335,8 +332,7 @@ config = mlc.ConfigDict(
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
-                "chunk_size": chunk_size,
-                "inf": 1e5,  # 1e9,
+                "inf": 1e9,
                 "eps": eps,  # 1e-10,
             },
             "structure_module": {
......
@@ -165,7 +165,7 @@ def generate_release_dates_cache(mmcif_dir: str, out_path: str):
             file_id=file_id, mmcif_string=mmcif_string
         )
         if mmcif.mmcif_object is None:
-            logging.warning(f"Failed to parse {f}. Skipping...")
+            logging.info(f"Failed to parse {f}. Skipping...")
             continue
         mmcif = mmcif.mmcif_object
@@ -822,7 +822,7 @@ def _process_single_hit(
         if strict_error_check:
             return SingleHitResult(features=None, error=error, warning=None)
         else:
-            logging.warning(error)
+            logging.info(error)
             return SingleHitResult(features=None, error=None, warning=None)
     try:
......
@@ -20,7 +20,7 @@ import os
 import subprocess
 from typing import Sequence
-from openfold.data.np import utils
+from openfold.data.tools import utils
 class HHSearch:
......
@@ -46,7 +46,7 @@ class MSATransition(nn.Module):
     Implements Algorithm 9
     """
-    def __init__(self, c_m, n, chunk_size):
+    def __init__(self, c_m, n):
         """
         Args:
             c_m:
@@ -59,7 +59,6 @@ class MSATransition(nn.Module):
         self.c_m = c_m
         self.n = n
-        self.chunk_size = chunk_size
         self.layer_norm = nn.LayerNorm(self.c_m)
         self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
@@ -76,6 +75,7 @@ class MSATransition(nn.Module):
         self,
         m: torch.Tensor,
         mask: torch.Tensor = None,
+        chunk_size: int = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -96,11 +96,11 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)
         inp = {"m": m, "mask": mask}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self._transition,
                 inp,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:
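Chunking is now decided per call: `forward` falls back to the dense path when `chunk_size` is `None`. For reference, `chunk_layer` slices every input tensor along its flattened batch dimensions, applies the wrapped function to each slice, and concatenates the results. A simplified sketch assuming a single batch dimension and tensor-only inputs (the real helper, imported from `openfold.utils.tensor_utils`, handles both generally):

```python
import torch

def chunk_layer_sketch(layer, inputs: dict, chunk_size: int, no_batch_dims: int = 1):
    """Toy stand-in for openfold's chunk_layer (single batch dim only)."""
    assert no_batch_dims == 1
    n = next(iter(inputs.values())).shape[0]
    out = []
    for i in range(0, n, chunk_size):
        # Apply the layer to one sub-batch of at most chunk_size rows
        chunk = {k: v[i : i + chunk_size] for k, v in inputs.items()}
        out.append(layer(**chunk))
    # Stitch the per-chunk outputs back into one tensor
    return torch.cat(out, dim=0)
```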
@@ -123,7 +123,6 @@ class EvoformerBlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        chunk_size: int,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
@@ -135,7 +134,6 @@ class EvoformerBlock(nn.Module):
             c_z=c_z,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
         )
@@ -144,7 +142,6 @@ class EvoformerBlock(nn.Module):
             c_in=c_m,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
             eps=eps,
         )
@@ -153,21 +150,18 @@ class EvoformerBlock(nn.Module):
             c_m,
             c_hidden_msa_att,
             no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.msa_transition = MSATransition(
             c_m=c_m,
             n=transition_n,
-            chunk_size=chunk_size,
         )
         self.outer_product_mean = OuterProductMean(
             c_m,
             c_z,
             c_hidden_opm,
-            chunk_size=chunk_size,
         )
         self.tri_mul_out = TriangleMultiplicationOutgoing(
@@ -183,21 +177,18 @@ class EvoformerBlock(nn.Module):
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.pair_transition = PairTransition(
             c_z,
             transition_n,
-            chunk_size=chunk_size,
         )
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
@@ -210,6 +201,7 @@ class EvoformerBlock(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
@@ -218,15 +210,27 @@ class EvoformerBlock(nn.Module):
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None
-        m = m + self.msa_dropout_layer(self.msa_att_row(m, z, mask=msa_mask))
-        m = m + self.msa_att_col(m, mask=msa_mask)
-        m = m + self.msa_transition(m, mask=msa_trans_mask)
-        z = z + self.outer_product_mean(m, mask=msa_mask)
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+        )
+        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_transition(
+            m, mask=msa_trans_mask, chunk_size=chunk_size
+        )
+        z = z + self.outer_product_mean(
+            m, mask=msa_mask, chunk_size=chunk_size
+        )
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask))
-        z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask))
-        z = z + self.pair_transition(z, mask=pair_trans_mask)
+        z = z + self.ps_dropout_row_layer(
+            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+        )
+        z = z + self.ps_dropout_col_layer(
+            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+        )
+        z = z + self.pair_transition(
+            z, mask=pair_trans_mask, chunk_size=chunk_size
+        )
         return m, z
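The block-level refactor is mechanical: every sub-module loses its `chunk_size` constructor argument and gains one on `forward`, so the same trained weights can run under different memory budgets without rebuilding the module tree. A toy illustration of the pattern (all names hypothetical):

```python
import torch
import torch.nn as nn

class ChunkableLinear(nn.Module):
    """Chunk size is call-time state, not construction-time state."""
    def __init__(self, c: int):
        super().__init__()
        self.linear = nn.Linear(c, c)

    def forward(self, x: torch.Tensor, chunk_size: int = None) -> torch.Tensor:
        if chunk_size is None:
            return self.linear(x)
        # Process the batch in sub-batches of chunk_size rows
        return torch.cat(
            [self.linear(c) for c in torch.split(x, chunk_size, dim=0)],
            dim=0,
        )

layer = ChunkableLinear(8)
x = torch.randn(32, 8)
assert torch.allclose(layer(x), layer(x, chunk_size=4), atol=1e-6)
```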
@@ -254,7 +258,6 @@ class EvoformerStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
-        chunk_size: int,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
@@ -312,7 +315,6 @@ class EvoformerStack(nn.Module):
                 transition_n=transition_n,
                 msa_dropout=msa_dropout,
                 pair_dropout=pair_dropout,
-                chunk_size=chunk_size,
                 inf=inf,
                 eps=eps,
                 _is_extra_msa_stack=_is_extra_msa_stack,
@@ -328,6 +330,7 @@ class EvoformerStack(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
@@ -354,6 +357,7 @@ class EvoformerStack(nn.Module):
                 b,
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
+                chunk_size=chunk_size,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
@@ -392,7 +396,6 @@ class ExtraMSAStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
-        chunk_size: int,
         inf: float,
         eps: float,
         **kwargs,
@@ -415,7 +418,6 @@ class ExtraMSAStack(nn.Module):
             msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
             blocks_per_ckpt=blocks_per_ckpt,
-            chunk_size=chunk_size,
             inf=inf,
             eps=eps,
             _is_extra_msa_stack=True,
@@ -425,6 +427,7 @@ class ExtraMSAStack(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
+        chunk_size: int,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,
@@ -447,6 +450,7 @@ class ExtraMSAStack(nn.Module):
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            chunk_size=chunk_size,
             _mask_trans=_mask_trans,
         )
         return z
@@ -63,6 +63,8 @@ class AlphaFold(nn.Module):
         """
         super(AlphaFold, self).__init__()
+        self.globals = config.globals
+        config = config.model
         template_config = config.template
         extra_msa_config = config.extra_msa
@@ -104,7 +106,7 @@ class AlphaFold(nn.Module):
         self.config = config
-    def embed_templates(self, batch, z, pair_mask, templ_dim):
+    def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
@@ -135,14 +137,13 @@ class AlphaFold(nn.Module):
             )
             t = self.template_pair_embedder(t)
             t = self.template_pair_stack(
-                t, pair_mask.unsqueeze(-3), _mask_trans=self.config._mask_trans
+                t,
+                pair_mask.unsqueeze(-3),
+                chunk_size=chunk_size,
+                _mask_trans=self.config._mask_trans,
             )
-            single_template_embeds.update(
-                {
-                    "pair": t,
-                }
-            )
+            single_template_embeds.update({"pair": t})
             template_embeds.append(single_template_embeds)
@@ -153,7 +154,10 @@ class AlphaFold(nn.Module):
         # [*, N, N, C_z]
         t = self.template_pointwise_att(
-            template_embeds["pair"], z, template_mask=batch["template_mask"]
+            template_embeds["pair"],
+            z,
+            template_mask=batch["template_mask"],
+            chunk_size=chunk_size,
         )
         t = t * (torch.sum(batch["template_mask"]) > 0)
@@ -161,15 +165,17 @@ class AlphaFold(nn.Module):
         if self.config.template.embed_angles:
             ret["template_angle_embedding"] = template_embeds["angle"]
-        ret.update(
-            {
-                "template_pair_embedding": t,
-            }
-        )
+        ret.update({"template_pair_embedding": t})
         return ret
     def iteration(self, feats, m_1_prev, z_prev, x_prev):
+        # Establish constants
+        chunk_size = (
+            self.globals.train_chunk_size
+            if self.training else self.globals.eval_chunk_size
+        )
         # Primary output dictionary
         outputs = {}
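`iteration` keys the effective chunk size off `nn.Module.training`, so with the config defaults above the model runs dense while training and sub-batches only in eval mode. A minimal sketch of the selection logic in isolation (class name hypothetical):

```python
import torch.nn as nn

class ChunkSizePolicy(nn.Module):
    """Mirrors the selection in AlphaFold.iteration() above."""
    def __init__(self, train_chunk_size=None, eval_chunk_size=4):
        super().__init__()
        self.train_chunk_size = train_chunk_size
        self.eval_chunk_size = eval_chunk_size

    def chunk_size(self):
        return self.train_chunk_size if self.training else self.eval_chunk_size

policy = ChunkSizePolicy()
policy.train()
assert policy.chunk_size() is None  # dense path while training
policy.eval()
assert policy.chunk_size() == 4     # sub-batching at inference
```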
@@ -243,6 +249,7 @@ class AlphaFold(nn.Module):
                 z,
                 pair_mask,
                 no_batch_dims,
+                chunk_size,
             )
         # [*, N, N, C_z]
@@ -270,6 +277,7 @@ class AlphaFold(nn.Module):
                 a,
                 z,
                 msa_mask=feats["extra_msa_mask"],
+                chunk_size=chunk_size,
                 pair_mask=pair_mask,
                 _mask_trans=self.config._mask_trans,
             )
@@ -283,6 +291,7 @@ class AlphaFold(nn.Module):
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            chunk_size=chunk_size,
             _mask_trans=self.config._mask_trans,
         )
......
@@ -34,7 +34,6 @@ class MSAAttention(nn.Module):
         no_heads,
         pair_bias=False,
         c_z=None,
-        chunk_size=4,
         inf=1e9,
     ):
         """
@@ -60,7 +59,6 @@ class MSAAttention(nn.Module):
         self.no_heads = no_heads
         self.pair_bias = pair_bias
         self.c_z = c_z
-        self.chunk_size = chunk_size
         self.inf = inf
         self.layer_norm_m = nn.LayerNorm(self.c_in)
@@ -75,7 +73,7 @@ class MSAAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )
-    def forward(self, m, z=None, mask=None):
+    def forward(self, m, chunk_size, z=None, mask=None):
         """
         Args:
             m:
@@ -117,11 +115,11 @@ class MSAAttention(nn.Module):
             biases.append(z)
         mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:
@@ -135,7 +133,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
     Implements Algorithm 7.
     """
-    def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
+    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
         """
         Args:
             c_m:
@@ -155,7 +153,6 @@ class MSARowAttentionWithPairBias(MSAAttention):
             no_heads,
             pair_bias=True,
             c_z=c_z,
-            chunk_size=chunk_size,
             inf=inf,
         )
@@ -165,7 +162,7 @@ class MSAColumnAttention(MSAAttention):
     Implements Algorithm 8.
     """
-    def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
+    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
         """
         Args:
             c_m:
@@ -183,11 +180,10 @@ class MSAColumnAttention(MSAAttention):
             no_heads=no_heads,
             pair_bias=False,
             c_z=None,
-            chunk_size=chunk_size,
             inf=inf,
         )
-    def forward(self, m, mask=None):
+    def forward(self, m, chunk_size, mask=None):
         """
         Args:
             m:
@@ -200,7 +196,7 @@ class MSAColumnAttention(MSAAttention):
         if mask is not None:
             mask = mask.transpose(-1, -2)
-        m = super().forward(m, mask=mask)
+        m = super().forward(m, chunk_size=chunk_size, mask=mask)
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
@@ -211,14 +207,13 @@ class MSAColumnAttention(MSAAttention):
 class MSAColumnGlobalAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, chunk_size=4, inf=1e9, eps=1e-10
+        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10
     ):
         super(MSAColumnGlobalAttention, self).__init__()
         self.c_in = c_in
         self.c_hidden = c_hidden
         self.no_heads = no_heads
-        self.chunk_size = chunk_size
         self.inf = inf
         self.eps = eps
@@ -233,7 +228,7 @@ class MSAColumnGlobalAttention(nn.Module):
         )
     def forward(
-        self, m: torch.Tensor, mask: Optional[torch.Tensor] = None
+        self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
@@ -256,11 +251,11 @@ class MSAColumnGlobalAttention(nn.Module):
             "m": m,
             "mask": mask,
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self.global_attention,
                 mha_input,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:
......
@@ -26,7 +26,7 @@ class OuterProductMean(nn.Module):
     Implements Algorithm 10.
     """
-    def __init__(self, c_m, c_z, c_hidden, chunk_size=4, eps=1e-3):
+    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
         """
         Args:
             c_m:
@@ -40,7 +40,6 @@ class OuterProductMean(nn.Module):
         self.c_z = c_z
         self.c_hidden = c_hidden
-        self.chunk_size = chunk_size
         self.eps = eps
         self.layer_norm = nn.LayerNorm(c_m)
@@ -60,7 +59,7 @@ class OuterProductMean(nn.Module):
         return outer
-    def forward(self, m, mask=None):
+    def forward(self, m, chunk_size, mask=None):
         """
         Args:
             m:
@@ -84,7 +83,7 @@ class OuterProductMean(nn.Module):
         a = a.transpose(-2, -3)
         b = b.transpose(-2, -3)
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             # Since the "batch dim" in this case is not a true batch dimension
             # (in that the shape of the output depends on it), we need to
             # iterate over it ourselves
@@ -95,7 +94,7 @@ class OuterProductMean(nn.Module):
             outer = chunk_layer(
                 partial(self._opm, b=b_prime),
                 {"a": a_prime},
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=1,
            )
            out.append(outer)
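The in-code comment explains why `OuterProductMean` cannot hand its leading dimension to `chunk_layer` wholesale: every slice of the residue dimension pairs with all residues, so the output grows with the dimension being iterated. A toy version of the outer-product mean, assuming plain tensors in place of the module's learned projections:

```python
import torch

def toy_opm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean over the MSA dim of outer products of per-residue features.

    a, b: [N_seq, N_res, C]  ->  output: [N_res, N_res, C, C]
    """
    return torch.einsum("sic,sjd->ijcd", a, b) / a.shape[0]

a, b = torch.randn(8, 16, 4), torch.randn(8, 16, 4)

# Chunking the first N_res dim produces row-blocks of the full output,
# which is why the module iterates that dim manually (no_batch_dims=1).
rows = [toy_opm(a[:, i : i + 4], b) for i in range(0, 16, 4)]
assert torch.allclose(torch.cat(rows, dim=0), toy_opm(a, b), atol=1e-5)
```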
......
@@ -25,7 +25,7 @@ class PairTransition(nn.Module):
     Implements Algorithm 15.
     """
-    def __init__(self, c_z, n, chunk_size=4):
+    def __init__(self, c_z, n):
         """
         Args:
             c_z:
@@ -38,7 +38,6 @@ class PairTransition(nn.Module):
         self.c_z = c_z
         self.n = n
-        self.chunk_size = chunk_size
         self.layer_norm = nn.LayerNorm(self.c_z)
         self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
@@ -55,7 +54,7 @@ class PairTransition(nn.Module):
         return z
-    def forward(self, z, mask=None):
+    def forward(self, z, chunk_size, mask=None):
         """
         Args:
             z:
@@ -74,11 +73,11 @@ class PairTransition(nn.Module):
         z = self.layer_norm(z)
         inp = {"z": z, "mask": mask}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             z = chunk_layer(
                 self._transition,
                 inp,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(z.shape[:-2]),
             )
         else:
......
@@ -260,11 +260,14 @@ class Attention(nn.Module):
         # [*, H, Q, K]
         a = torch.matmul(
-            permute_final_dims(q, (0, 2, 1, 3)),  # [*, H, Q, C_hidden]
-            permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
+            permute_final_dims(q, (1, 0, 2)),  # [*, H, Q, C_hidden]
+            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, K]
         )
+        del q, k
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a = a * norm
+        a *= norm
         if biases is not None:
             for b in biases:
                 a = a + b
@@ -273,7 +276,7 @@ class Attention(nn.Module):
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            permute_final_dims(v, (0, 2, 1, 3)),  # [*, H, V, C_hidden]
+            permute_final_dims(v, (1, 0, 2)),  # [*, H, V, C_hidden]
         )
         # [*, Q, H, C_hidden]
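The permutation tuples shrink from four entries to three because `permute_final_dims` addresses only the final dimensions, leaving any number of leading batch dimensions implicit; `(1, 0, 2)` swaps the Q and H axes of a `[*, Q, H, C_hidden]` tensor, which is what lets `chunk_layer` flatten arbitrarily many batch dims around this module. A sketch consistent with its usage here, assuming the helper's behavior (the real one lives in `openfold.utils.tensor_utils`):

```python
import torch

def permute_final_dims(tensor: torch.Tensor, inds) -> torch.Tensor:
    """Permute the last len(inds) dims, leaving leading dims untouched."""
    zero_index = -1 * len(inds)
    first = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first + [zero_index + i for i in inds])

# [*, Q, H, C_hidden] -> [*, H, Q, C_hidden] for any batch rank
q = torch.randn(2, 3, 10, 4, 8)
assert permute_final_dims(q, (1, 0, 2)).shape == (2, 3, 4, 10, 8)
```

The added `del q, k` and the in-place `a *= norm` trim peak memory between the two matmuls.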
......
@@ -45,7 +45,7 @@ class TemplatePointwiseAttention(nn.Module):
     Implements Algorithm 17.
     """
-    def __init__(self, c_t, c_z, c_hidden, no_heads, chunk_size, inf, **kwargs):
+    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
         """
         Args:
             c_t:
@@ -61,7 +61,6 @@ class TemplatePointwiseAttention(nn.Module):
         self.c_z = c_z
         self.c_hidden = c_hidden
         self.no_heads = no_heads
-        self.chunk_size = chunk_size
         self.inf = inf
         self.mha = Attention(
@@ -73,7 +72,7 @@ class TemplatePointwiseAttention(nn.Module):
             gating=False,
         )
-    def forward(self, t, z, template_mask=None):
+    def forward(self, t, z, chunk_size, template_mask=None):
         """
         Args:
             t:
@@ -106,11 +105,11 @@ class TemplatePointwiseAttention(nn.Module):
             "v_x": t,
             "biases": [bias],
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             z = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(z.shape[:-2]),
             )
         else:
@@ -131,7 +130,6 @@ class TemplatePairStackBlock(nn.Module):
         no_heads,
         pair_transition_n,
         dropout_rate,
-        chunk_size,
         inf,
         **kwargs,
     ):
@@ -143,7 +141,6 @@ class TemplatePairStackBlock(nn.Module):
         self.no_heads = no_heads
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
-        self.chunk_size = chunk_size
         self.inf = inf
         self.dropout_row = DropoutRowwise(self.dropout_rate)
@@ -153,14 +150,12 @@ class TemplatePairStackBlock(nn.Module):
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
-            chunk_size=chunk_size,
             inf=inf,
         )
@@ -176,15 +171,20 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition = PairTransition(
             self.c_t,
             self.pair_transition_n,
-            chunk_size=chunk_size,
         )
-    def forward(self, z, mask, _mask_trans=True):
-        z = z + self.dropout_row(self.tri_att_start(z, mask=mask))
-        z = z + self.dropout_col(self.tri_att_end(z, mask=mask))
+    def forward(self, z, mask, chunk_size, _mask_trans=True):
+        z = z + self.dropout_row(
+            self.tri_att_start(z, chunk_size=chunk_size, mask=mask)
+        )
+        z = z + self.dropout_col(
+            self.tri_att_end(z, chunk_size=chunk_size, mask=mask)
+        )
         z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
         z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
-        z = z + self.pair_transition(z, mask=mask if _mask_trans else None)
+        z = z + self.pair_transition(
+            z, chunk_size=chunk_size, mask=mask if _mask_trans else None
+        )
         return z
@@ -204,7 +204,6 @@ class TemplatePairStack(nn.Module):
         pair_transition_n,
         dropout_rate,
         blocks_per_ckpt,
-        chunk_size,
         inf=1e9,
         **kwargs,
     ):
@@ -225,9 +224,6 @@ class TemplatePairStack(nn.Module):
             blocks_per_ckpt:
                 Number of blocks per activation checkpoint. None disables
                 activation checkpointing
-            chunk_size:
-                Size of subbatches. A higher value increases throughput at
-                the cost of memory
         """
         super(TemplatePairStack, self).__init__()
@@ -242,7 +238,6 @@ class TemplatePairStack(nn.Module):
             no_heads=no_heads,
             pair_transition_n=pair_transition_n,
             dropout_rate=dropout_rate,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.blocks.append(block)
@@ -253,6 +248,7 @@ class TemplatePairStack(nn.Module):
         self,
         t: torch.tensor,
         mask: torch.tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ):
         """
@@ -269,6 +265,7 @@ class TemplatePairStack(nn.Module):
             partial(
                 b,
                 mask=mask,
+                chunk_size=chunk_size,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
......
@@ -28,7 +28,7 @@ from openfold.utils.tensor_utils import (
 class TriangleAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, starting, chunk_size=4, inf=1e9
+        self, c_in, c_hidden, no_heads, starting, inf=1e9
     ):
         """
         Args:
@@ -45,7 +45,6 @@ class TriangleAttention(nn.Module):
         self.c_hidden = c_hidden
         self.no_heads = no_heads
         self.starting = starting
-        self.chunk_size = chunk_size
         self.inf = inf
         self.layer_norm = nn.LayerNorm(self.c_in)
@@ -56,7 +55,7 @@ class TriangleAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )
-    def forward(self, x, mask=None):
+    def forward(self, x, chunk_size, mask=None):
         """
         Args:
             x:
@@ -93,11 +92,11 @@ class TriangleAttention(nn.Module):
             "v_x": x,
             "biases": [mask_bias, triangle_bias],
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             x = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(x.shape[:-2]),
             )
         else:
......
@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+os.environ["CUDA_VISIBLE_DEVICES"] = "5"
 import random
 import time
@@ -26,7 +26,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def __init__(self, config):
         super(OpenFoldWrapper, self).__init__()
         self.config = config
-        self.model = AlphaFold(config.model)
+        self.model = AlphaFold(config)
         self.loss = AlphaFoldLoss(config.loss)
         self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay)
@@ -50,6 +50,9 @@ class OpenFoldWrapper(pl.LightningModule):
         with open("prediction/preds_" + str(time.strftime("%H:%M:%S")) + ".pickle", "wb") as f:
             pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL)
+    #def validation_step(self, batch, batch_idx):
+    #    outputs = self(batch)
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-8
@@ -64,6 +67,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)
 def main(args):
     config = model_config(
         "model_1",
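Because `AlphaFold.__init__` now stores `config.globals` and narrows itself to `config.model` (see the hunks above), the wrapper must pass the whole config rather than `config.model`. Roughly, with import paths assumed rather than confirmed at this commit:

```python
from openfold.config import model_config
from openfold.model.model import AlphaFold  # assumed module path

config = model_config("model_1", train=True)
model = AlphaFold(config)  # full config; the model reads config.globals
                           # before narrowing itself to config.model
```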