"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "63d6540ca692d62fe5447bfe7c63233126170774"
Commit b026de28 authored by Gustaf Ahdritz

Make chunk size more flexible, reduce verbosity

parent a59ae7c1
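
Overview of the change: each chunked module (MSA attention, MSA/pair transitions, outer-product mean, triangle attention, template pointwise attention) used to receive chunk_size in its constructor and store it as self.chunk_size. This commit removes that constructor argument everywhere and instead threads chunk_size through forward(), with AlphaFold.iteration() picking globals.train_chunk_size or globals.eval_chunk_size depending on self.training. The sketch below uses illustrative module names, not the real OpenFold classes, and only shows the before/after pattern the diff applies across the codebase.

import torch
import torch.nn as nn
from typing import Optional

class TransitionBefore(nn.Module):
    # Old pattern: chunk size fixed when the module is built
    def __init__(self, c: int, chunk_size: Optional[int] = 4):
        super().__init__()
        self.chunk_size = chunk_size
        self.linear = nn.Linear(c, c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.chunk_size is not None:
            chunks = x.split(self.chunk_size, dim=-2)
            return torch.cat([self.linear(c_) for c_ in chunks], dim=-2)
        return self.linear(x)

class TransitionAfter(nn.Module):
    # New pattern: the caller decides per forward pass
    def __init__(self, c: int):
        super().__init__()
        self.linear = nn.Linear(c, c)

    def forward(self, x: torch.Tensor, chunk_size: Optional[int] = None) -> torch.Tensor:
        if chunk_size is not None:
            chunks = x.split(chunk_size, dim=-2)
            return torch.cat([self.linear(c_) for c_ in chunks], dim=-2)
        return self.linear(x)

With the constructor-time value, switching between chunked inference and unchunked training meant rebuilding the model; with the per-call value it becomes a runtime decision made once at the top of AlphaFold.iteration().
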
......@@ -45,13 +45,12 @@ def model_config(name, train=False, low_prec=False):
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
set_inf(c, 1e5)
return c
......@@ -225,7 +224,8 @@ config = mlc.ConfigDict(
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"train_chunk_size": None,
"eval_chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
......@@ -277,8 +277,7 @@ config = mlc.ConfigDict(
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
"inf": 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
......@@ -287,7 +286,6 @@ config = mlc.ConfigDict(
# It's actually 16.
"c_hidden": 16,
"no_heads": 4,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
},
"inf": 1e5, # 1e9,
......@@ -314,8 +312,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"enabled": True,
......@@ -335,8 +332,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
......
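
On the config side, the single globals.chunk_size reference is replaced by train_chunk_size (None by default, so training stays unchunked) and eval_chunk_size, and the per-module chunk_size entries are dropped. A minimal usage sketch, assuming config.py is importable as openfold.config as in the upstream layout; as before, a higher chunk size trades memory for throughput:

from openfold.config import model_config  # assumed import path

cfg = model_config("model_1", train=False)
cfg.globals.eval_chunk_size = 128     # larger sub-batches: more throughput, more memory
cfg.globals.train_chunk_size = None   # leave training unchunked, as in the preset
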
......@@ -165,7 +165,7 @@ def generate_release_dates_cache(mmcif_dir: str, out_path: str):
file_id=file_id, mmcif_string=mmcif_string
)
if mmcif.mmcif_object is None:
logging.warning(f"Failed to parse {f}. Skipping...")
logging.info(f"Failed to parse {f}. Skipping...")
continue
mmcif = mmcif.mmcif_object
......@@ -822,7 +822,7 @@ def _process_single_hit(
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
logging.warning(error)
logging.info(error)
return SingleHitResult(features=None, error=None, warning=None)
try:
......
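
The warning-to-info downgrades above are the "reduce verbosity" half of the commit: with the standard-library logger at its default WARNING threshold, per-file mmCIF parse failures and non-strict template errors are no longer printed. To opt back in when debugging the data pipeline:

import logging

# Re-enable the downgraded messages (assumes the default stdlib logging setup)
logging.basicConfig(level=logging.INFO)
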
......@@ -20,7 +20,7 @@ import os
import subprocess
from typing import Sequence
from openfold.data.np import utils
from openfold.data.tools import utils
class HHSearch:
......
......@@ -46,7 +46,7 @@ class MSATransition(nn.Module):
Implements Algorithm 9
"""
def __init__(self, c_m, n, chunk_size):
def __init__(self, c_m, n):
"""
Args:
c_m:
......@@ -59,7 +59,6 @@ class MSATransition(nn.Module):
self.c_m = c_m
self.n = n
self.chunk_size = chunk_size
self.layer_norm = nn.LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
......@@ -76,6 +75,7 @@ class MSATransition(nn.Module):
self,
m: torch.Tensor,
mask: torch.Tensor = None,
chunk_size: int = None,
) -> torch.Tensor:
"""
Args:
......@@ -96,11 +96,11 @@ class MSATransition(nn.Module):
m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if self.chunk_size is not None:
if chunk_size is not None:
m = chunk_layer(
self._transition,
inp,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else:
......@@ -123,7 +123,6 @@ class EvoformerBlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
chunk_size: int,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
......@@ -135,7 +134,6 @@ class EvoformerBlock(nn.Module):
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
chunk_size=chunk_size,
inf=inf,
)
......@@ -144,7 +142,6 @@ class EvoformerBlock(nn.Module):
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
chunk_size=chunk_size,
inf=inf,
eps=eps,
)
......@@ -153,21 +150,18 @@ class EvoformerBlock(nn.Module):
c_m,
c_hidden_msa_att,
no_heads_msa,
chunk_size=chunk_size,
inf=inf,
)
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
chunk_size=chunk_size,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
chunk_size=chunk_size,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
......@@ -183,21 +177,18 @@ class EvoformerBlock(nn.Module):
c_z,
c_hidden_pair_att,
no_heads_pair,
chunk_size=chunk_size,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
c_z,
c_hidden_pair_att,
no_heads_pair,
chunk_size=chunk_size,
inf=inf,
)
self.pair_transition = PairTransition(
c_z,
transition_n,
chunk_size=chunk_size,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
......@@ -210,6 +201,7 @@ class EvoformerBlock(nn.Module):
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
......@@ -218,15 +210,27 @@ class EvoformerBlock(nn.Module):
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_dropout_layer(self.msa_att_row(m, z, mask=msa_mask))
m = m + self.msa_att_col(m, mask=msa_mask)
m = m + self.msa_transition(m, mask=msa_trans_mask)
z = z + self.outer_product_mean(m, mask=msa_mask)
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask))
z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask))
z = z + self.pair_transition(z, mask=pair_trans_mask)
z = z + self.ps_dropout_row_layer(
self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.ps_dropout_col_layer(
self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size
)
return m, z
......@@ -254,7 +258,6 @@ class EvoformerStack(nn.Module):
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
chunk_size: int,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
......@@ -312,7 +315,6 @@ class EvoformerStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
chunk_size=chunk_size,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
......@@ -328,6 +330,7 @@ class EvoformerStack(nn.Module):
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
......@@ -354,6 +357,7 @@ class EvoformerStack(nn.Module):
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
......@@ -392,7 +396,6 @@ class ExtraMSAStack(nn.Module):
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
chunk_size: int,
inf: float,
eps: float,
**kwargs,
......@@ -415,7 +418,6 @@ class ExtraMSAStack(nn.Module):
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
blocks_per_ckpt=blocks_per_ckpt,
chunk_size=chunk_size,
inf=inf,
eps=eps,
_is_extra_msa_stack=True,
......@@ -425,6 +427,7 @@ class ExtraMSAStack(nn.Module):
self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
......@@ -447,6 +450,7 @@ class ExtraMSAStack(nn.Module):
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
return z
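
Every chunked forward() above funnels its work through chunk_layer. For readers following the diff, here is a simplified stand-in for that utility, showing the contract the call sites rely on: flatten the leading batch dims of each keyword tensor, slice them into sub-batches of chunk_size, apply the wrapped layer, and stitch the outputs back together. The real utility in openfold.utils also handles nested inputs and broadcasting; this sketch assumes every input shares the same leading batch dims.

import torch
from typing import Callable, Dict

def chunk_layer_sketch(
    layer: Callable[..., torch.Tensor],
    inputs: Dict[str, torch.Tensor],
    chunk_size: int,
    no_batch_dims: int,
) -> torch.Tensor:
    # Flatten the leading batch dims so every input can be sliced uniformly
    flat = {k: t.reshape(-1, *t.shape[no_batch_dims:]) for k, t in inputs.items()}
    n = next(iter(flat.values())).shape[0]
    outs = []
    for start in range(0, n, chunk_size):
        sub = {k: t[start:start + chunk_size] for k, t in flat.items()}
        outs.append(layer(**sub))
    out = torch.cat(outs, dim=0)
    # Restore the original batch shape on the output
    batch_shape = next(iter(inputs.values())).shape[:no_batch_dims]
    return out.reshape(*batch_shape, *out.shape[1:])
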
......@@ -63,6 +63,8 @@ class AlphaFold(nn.Module):
"""
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
......@@ -104,7 +106,7 @@ class AlphaFold(nn.Module):
self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim):
def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
......@@ -135,14 +137,13 @@ class AlphaFold(nn.Module):
)
t = self.template_pair_embedder(t)
t = self.template_pair_stack(
t, pair_mask.unsqueeze(-3), _mask_trans=self.config._mask_trans
t,
pair_mask.unsqueeze(-3),
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
single_template_embeds.update(
{
"pair": t,
}
)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
......@@ -153,7 +154,10 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
t = self.template_pointwise_att(
template_embeds["pair"], z, template_mask=batch["template_mask"]
template_embeds["pair"],
z,
template_mask=batch["template_mask"],
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
......@@ -161,15 +165,17 @@ class AlphaFold(nn.Module):
if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update(
{
"template_pair_embedding": t,
}
)
ret.update({"template_pair_embedding": t})
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev):
# Establish constants
chunk_size = (
self.globals.train_chunk_size
if self.training else self.globals.eval_chunk_size
)
# Primary output dictionary
outputs = {}
......@@ -243,6 +249,7 @@ class AlphaFold(nn.Module):
z,
pair_mask,
no_batch_dims,
chunk_size,
)
# [*, N, N, C_z]
......@@ -270,6 +277,7 @@ class AlphaFold(nn.Module):
a,
z,
msa_mask=feats["extra_msa_mask"],
chunk_size=chunk_size,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans,
)
......@@ -283,6 +291,7 @@ class AlphaFold(nn.Module):
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
......
......@@ -34,7 +34,6 @@ class MSAAttention(nn.Module):
no_heads,
pair_bias=False,
c_z=None,
chunk_size=4,
inf=1e9,
):
"""
......@@ -60,7 +59,6 @@ class MSAAttention(nn.Module):
self.no_heads = no_heads
self.pair_bias = pair_bias
self.c_z = c_z
self.chunk_size = chunk_size
self.inf = inf
self.layer_norm_m = nn.LayerNorm(self.c_in)
......@@ -75,7 +73,7 @@ class MSAAttention(nn.Module):
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
def forward(self, m, z=None, mask=None):
def forward(self, m, chunk_size, z=None, mask=None):
"""
Args:
m:
......@@ -117,11 +115,11 @@ class MSAAttention(nn.Module):
biases.append(z)
mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
if self.chunk_size is not None:
if chunk_size is not None:
m = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else:
......@@ -135,7 +133,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
......@@ -155,7 +153,6 @@ class MSARowAttentionWithPairBias(MSAAttention):
no_heads,
pair_bias=True,
c_z=c_z,
chunk_size=chunk_size,
inf=inf,
)
......@@ -165,7 +162,7 @@ class MSAColumnAttention(MSAAttention):
Implements Algorithm 8.
"""
def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
......@@ -183,11 +180,10 @@ class MSAColumnAttention(MSAAttention):
no_heads=no_heads,
pair_bias=False,
c_z=None,
chunk_size=chunk_size,
inf=inf,
)
def forward(self, m, mask=None):
def forward(self, m, chunk_size, mask=None):
"""
Args:
m:
......@@ -200,7 +196,7 @@ class MSAColumnAttention(MSAAttention):
if mask is not None:
mask = mask.transpose(-1, -2)
m = super().forward(m, mask=mask)
m = super().forward(m, chunk_size=chunk_size, mask=mask)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......@@ -211,14 +207,13 @@ class MSAColumnAttention(MSAAttention):
class MSAColumnGlobalAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, chunk_size=4, inf=1e9, eps=1e-10
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10
):
super(MSAColumnGlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.chunk_size = chunk_size
self.inf = inf
self.eps = eps
......@@ -233,7 +228,7 @@ class MSAColumnGlobalAttention(nn.Module):
)
def forward(
self, m: torch.Tensor, mask: Optional[torch.Tensor] = None
self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
......@@ -256,11 +251,11 @@ class MSAColumnGlobalAttention(nn.Module):
"m": m,
"mask": mask,
}
if self.chunk_size is not None:
if chunk_size is not None:
m = chunk_layer(
self.global_attention,
mha_input,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else:
......
......@@ -26,7 +26,7 @@ class OuterProductMean(nn.Module):
Implements Algorithm 10.
"""
def __init__(self, c_m, c_z, c_hidden, chunk_size=4, eps=1e-3):
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
"""
Args:
c_m:
......@@ -40,7 +40,6 @@ class OuterProductMean(nn.Module):
self.c_z = c_z
self.c_hidden = c_hidden
self.chunk_size = chunk_size
self.eps = eps
self.layer_norm = nn.LayerNorm(c_m)
......@@ -60,7 +59,7 @@ class OuterProductMean(nn.Module):
return outer
def forward(self, m, mask=None):
def forward(self, m, chunk_size, mask=None):
"""
Args:
m:
......@@ -84,7 +83,7 @@ class OuterProductMean(nn.Module):
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if self.chunk_size is not None:
if chunk_size is not None:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
......@@ -95,7 +94,7 @@ class OuterProductMean(nn.Module):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=1,
)
out.append(outer)
......
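
The comment kept above explains why OuterProductMean cannot hand everything to chunk_layer directly: the residue dimension being chunked also appears in the output, so chunks are computed with the second operand held whole (via partial) and concatenated along that dimension. The small self-check below, with made-up shapes and names, illustrates why chunking that dimension is still exact:

import torch

n_res, c = 6, 3
a = torch.randn(n_res, c)   # stands in for a_prime: [N_res, C_hidden]
b = torch.randn(n_res, c)   # stands in for b_prime: [N_res, C_hidden]

full = torch.einsum("ic,jd->ijcd", a, b)

chunk_size = 2
chunks = [
    torch.einsum("ic,jd->ijcd", a[s:s + chunk_size], b)
    for s in range(0, n_res, chunk_size)
]
# Chunked outer products concatenated along the chunked residue dim match the full result
assert torch.allclose(full, torch.cat(chunks, dim=0))
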
......@@ -25,7 +25,7 @@ class PairTransition(nn.Module):
Implements Algorithm 15.
"""
def __init__(self, c_z, n, chunk_size=4):
def __init__(self, c_z, n):
"""
Args:
c_z:
......@@ -38,7 +38,6 @@ class PairTransition(nn.Module):
self.c_z = c_z
self.n = n
self.chunk_size = chunk_size
self.layer_norm = nn.LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
......@@ -55,7 +54,7 @@ class PairTransition(nn.Module):
return z
def forward(self, z, mask=None):
def forward(self, z, chunk_size, mask=None):
"""
Args:
z:
......@@ -74,11 +73,11 @@ class PairTransition(nn.Module):
z = self.layer_norm(z)
inp = {"z": z, "mask": mask}
if self.chunk_size is not None:
if chunk_size is not None:
z = chunk_layer(
self._transition,
inp,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
else:
......
......@@ -260,11 +260,14 @@ class Attention(nn.Module):
# [*, H, Q, K]
a = torch.matmul(
permute_final_dims(q, (0, 2, 1, 3)), # [*, H, Q, C_hidden]
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
permute_final_dims(q, (1, 0, 2)), # [*, H, Q, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, K]
)
del q, k
norm = 1 / math.sqrt(self.c_hidden) # [1]
a = a * norm
a *= norm
if biases is not None:
for b in biases:
a = a + b
......@@ -273,7 +276,7 @@ class Attention(nn.Module):
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
permute_final_dims(v, (0, 2, 1, 3)), # [*, H, V, C_hidden]
permute_final_dims(v, (1, 0, 2)), # [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden]
......
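
As I read the Attention change, the permutation tuples shrink from four entries to three because permute_final_dims indexes only the trailing dimensions, so the call no longer assumes exactly one leading batch dim; with a single batch dim the old and new calls produce the same [*, H, Q, C_hidden] layout. Below is a stand-in that behaves the way the new call sites assume, for illustration only, not the repo's tensor_utils implementation:

import torch
from typing import Tuple

def permute_final_dims_sketch(t: torch.Tensor, perm: Tuple[int, ...]) -> torch.Tensor:
    # Permute only the last len(perm) dims; leading batch dims stay in place
    zero_index = t.dim() - len(perm)
    leading = list(range(zero_index))
    return t.permute(*leading, *[zero_index + p for p in perm])

# With two leading batch dims, (1, 0, 2) swaps Q and H and leaves the batches alone:
q = torch.randn(2, 5, 7, 4, 8)   # [B1, B2, Q, H, C_hidden]
assert permute_final_dims_sketch(q, (1, 0, 2)).shape == (2, 5, 4, 7, 8)
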
......@@ -45,7 +45,7 @@ class TemplatePointwiseAttention(nn.Module):
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, chunk_size, inf, **kwargs):
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
c_t:
......@@ -61,7 +61,6 @@ class TemplatePointwiseAttention(nn.Module):
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.chunk_size = chunk_size
self.inf = inf
self.mha = Attention(
......@@ -73,7 +72,7 @@ class TemplatePointwiseAttention(nn.Module):
gating=False,
)
def forward(self, t, z, template_mask=None):
def forward(self, t, z, chunk_size, template_mask=None):
"""
Args:
t:
......@@ -106,11 +105,11 @@ class TemplatePointwiseAttention(nn.Module):
"v_x": t,
"biases": [bias],
}
if self.chunk_size is not None:
if chunk_size is not None:
z = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
else:
......@@ -131,7 +130,6 @@ class TemplatePairStackBlock(nn.Module):
no_heads,
pair_transition_n,
dropout_rate,
chunk_size,
inf,
**kwargs,
):
......@@ -143,7 +141,6 @@ class TemplatePairStackBlock(nn.Module):
self.no_heads = no_heads
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.chunk_size = chunk_size
self.inf = inf
self.dropout_row = DropoutRowwise(self.dropout_rate)
......@@ -153,14 +150,12 @@ class TemplatePairStackBlock(nn.Module):
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
inf=inf,
)
......@@ -176,15 +171,20 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition = PairTransition(
self.c_t,
self.pair_transition_n,
chunk_size=chunk_size,
)
def forward(self, z, mask, _mask_trans=True):
z = z + self.dropout_row(self.tri_att_start(z, mask=mask))
z = z + self.dropout_col(self.tri_att_end(z, mask=mask))
def forward(self, z, mask, chunk_size, _mask_trans=True):
z = z + self.dropout_row(
self.tri_att_start(z, chunk_size=chunk_size, mask=mask)
)
z = z + self.dropout_col(
self.tri_att_end(z, chunk_size=chunk_size, mask=mask)
)
z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
z = z + self.pair_transition(z, mask=mask if _mask_trans else None)
z = z + self.pair_transition(
z, chunk_size=chunk_size, mask=mask if _mask_trans else None
)
return z
......@@ -204,7 +204,6 @@ class TemplatePairStack(nn.Module):
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
chunk_size,
inf=1e9,
**kwargs,
):
......@@ -225,9 +224,6 @@ class TemplatePairStack(nn.Module):
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
chunk_size:
Size of subbatches. A higher value increases throughput at
the cost of memory
"""
super(TemplatePairStack, self).__init__()
......@@ -242,7 +238,6 @@ class TemplatePairStack(nn.Module):
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
chunk_size=chunk_size,
inf=inf,
)
self.blocks.append(block)
......@@ -253,6 +248,7 @@ class TemplatePairStack(nn.Module):
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
......@@ -269,6 +265,7 @@ class TemplatePairStack(nn.Module):
partial(
b,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
......
......@@ -28,7 +28,7 @@ from openfold.utils.tensor_utils import (
class TriangleAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, starting, chunk_size=4, inf=1e9
self, c_in, c_hidden, no_heads, starting, inf=1e9
):
"""
Args:
......@@ -45,7 +45,6 @@ class TriangleAttention(nn.Module):
self.c_hidden = c_hidden
self.no_heads = no_heads
self.starting = starting
self.chunk_size = chunk_size
self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in)
......@@ -56,7 +55,7 @@ class TriangleAttention(nn.Module):
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
def forward(self, x, mask=None):
def forward(self, x, chunk_size, mask=None):
"""
Args:
x:
......@@ -93,11 +92,11 @@ class TriangleAttention(nn.Module):
"v_x": x,
"biases": [mask_bias, triangle_bias],
}
if self.chunk_size is not None:
if chunk_size is not None:
x = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
)
else:
......
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import random
import time
......@@ -26,7 +26,7 @@ class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config.model)
self.model = AlphaFold(config)
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay)
......@@ -50,6 +50,9 @@ class OpenFoldWrapper(pl.LightningModule):
with open("prediction/preds_" + str(time.strftime("%H:%M:%S")) + ".pickle", "wb") as f:
pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL)
#def validation_step(self, batch, batch_idx):
# outputs = self(batch)
def configure_optimizers(self,
learning_rate: float = 1e-3,
eps: float = 1e-8
......@@ -64,6 +67,7 @@ class OpenFoldWrapper(pl.LightningModule):
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
def main(args):
config = model_config(
"model_1",
......