"lib/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "4b6cfc1be0dea7fa5bac0f218645b92846e3a5e5"
Commit d8ee9c5f authored by Christina Floristean

All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.

parent 260db67f
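The reordering mentioned in the commit message is exposed through two new config flags, opm_first and tri_mul_first, added in the diff below: the monomer config keeps the original ordering (both False), while the multimer config update enables both. The following is a minimal sketch of the control flow these flags toggle, using hypothetical placeholder callables rather than the real OpenFold modules:

def evoformer_block(m, z, opm, msa_update, pair_stack, opm_first=False):
    # opm_first=True (multimer): the outer product mean updates the pair
    # representation before the MSA track runs; False keeps the original order.
    if opm_first:
        z = z + opm(m)
    m = msa_update(m, z)  # MSA row/column attention + transition
    if not opm_first:
        z = z + opm(m)
    z = pair_stack(z)  # triangle updates + pair transition
    return m, z

def template_pair_block(t, tri_att, tri_mul, tri_mul_first=False):
    # tri_mul_first=True (multimer): triangle multiplicative updates run
    # before triangle attention in the template pair stack.
    if tri_mul_first:
        return tri_att(tri_mul(t))
    return tri_mul(tri_att(t))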
.DS_Store
*.DS_Store
**/.DS_Store
.idea/
**/__pycache__
*.pyc
build/
dist/
*.egg-info/
openfold/resources
**/stereo_chemical_props.txt
**/sample_feats.pickle
@@ -331,6 +331,7 @@ config = mlc.ConfigDict(
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
+"tri_mul_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
@@ -367,6 +368,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
+"opm_first": False,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
@@ -388,6 +390,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
+"opm_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
@@ -546,6 +549,7 @@ multimer_model_config_update = {
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
+"tri_mul_first": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
@@ -555,6 +559,53 @@ multimer_model_config_update = {
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
+"use_unit_vector": True
+},
+"extra_msa": {
+"extra_msa_embedder": {
+"c_in": 25,
+"c_out": c_e,
+},
+"extra_msa_stack": {
+"c_m": c_e,
+"c_z": c_z,
+"c_hidden_msa_att": 8,
+"c_hidden_opm": 32,
+"c_hidden_mul": 128,
+"c_hidden_pair_att": 32,
+"no_heads_msa": 8,
+"no_heads_pair": 4,
+"no_blocks": 4,
+"transition_n": 4,
+"msa_dropout": 0.15,
+"pair_dropout": 0.25,
+"opm_first": True,
+"clear_cache_between_blocks": True,
+"inf": 1e9,
+"eps": eps, # 1e-10,
+"ckpt": blocks_per_ckpt is not None,
+},
+"enabled": True,
+},
+"evoformer_stack": {
+"c_m": c_m,
+"c_z": c_z,
+"c_hidden_msa_att": 32,
+"c_hidden_opm": 32,
+"c_hidden_mul": 128,
+"c_hidden_pair_att": 32,
+"c_s": c_s,
+"no_heads_msa": 8,
+"no_heads_pair": 4,
+"no_blocks": 48,
+"transition_n": 4,
+"msa_dropout": 0.15,
+"pair_dropout": 0.25,
+"opm_first": True,
+"blocks_per_ckpt": blocks_per_ckpt,
+"clear_cache_between_blocks": False,
+"inf": 1e9,
+"eps": eps, # 1e-10,
+},
"heads": {
"lddt": {
......
@@ -93,7 +93,7 @@ def fix_templates_aatype(protein):
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
-new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
@@ -669,8 +669,8 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch):
batch = tree_map(
-lambda n: torch.tensor(n, device=batch["aatype"].device),
+lambda n: torch.tensor(n, device="cpu"),
batch,
np.ndarray
)
out = make_atom14_masks(batch)
......
@@ -1048,7 +1048,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for i in idx:
# We got all the templates we wanted, stop processing hits.
-if len(already_seen) >= self.max_hits:
+if len(already_seen) >= self._max_hits:
break
hit = filtered[i]
@@ -1088,16 +1088,29 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for k in template_features:
template_features[k].append(result.features[k])
-for name in template_features:
-if num_hits > 0:
-template_features[name] = np.stack(
-template_features[name], axis=0
-).astype(TEMPLATE_FEATURES[name])
-else:
-# Make sure the feature has correct dtype even if empty.
-template_features[name] = np.array(
-[], dtype=TEMPLATE_FEATURES[name]
-)
+if already_seen:
+for name in template_features:
+template_features[name] = np.stack(
+template_features[name], axis=0
+).astype(TEMPLATE_FEATURES[name])
+else:
+num_res = len(query_sequence)
+# Construct a default template with all zeros.
+template_features = {
+"template_aatype": np.zeros(
+(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
+np.float32
+),
+"template_all_atom_masks": np.zeros(
+(1, num_res, residue_constants.atom_type_num), np.float32
+),
+"template_all_atom_positions": np.zeros(
+(1, num_res, residue_constants.atom_type_num, 3), np.float32
+),
+"template_domain_names": np.array([''.encode()], dtype=np.object),
+"template_sequence": np.array([''.encode()], dtype=np.object),
+"template_sum_probs": np.array([0], dtype=np.float32),
+}
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
......
@@ -216,10 +216,12 @@ class InputEmbedderMultimer(nn.Module):
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
-rel_pos = torch.nn.functional.one_hot(
-final_offset,
-2 * self.max_relative_idx + 2,
-)
+boundaries = torch.arange(
+start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
+)
+rel_pos = one_hot(
+final_offset,
+boundaries,
+)
rel_feats.append(rel_pos)
@@ -245,15 +247,21 @@ class InputEmbedderMultimer(nn.Module):
torch.ones_like(clipped_rel_chain)
)
-rel_chain = torch.nn.functional.one_hot(
-final_rel_chain,
-2 * max_rel_chain + 2,
-)
+boundaries = torch.arange(
+start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
+)
+rel_chain = one_hot(
+final_rel_chain,
+boundaries,
+)
rel_feats.append(rel_chain)
else:
-rel_pos = torch.nn.functional.one_hot(
-clipped_offset, 2 * self.max_relative_idx + 1,
-)
+boundaries = torch.arange(
+start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
+)
+rel_pos = one_hot(
+clipped_offset, boundaries,
+)
rel_feats.append(rel_pos)
@@ -471,102 +479,6 @@ class TemplatePairEmbedder(nn.Module):
return x
-class TemplateEmbedder(nn.Module):
-def __init__(
-self,
-config,
-):
-super().__init__()
-self.config = config
-self.template_angle_embedder = TemplateAngleEmbedder(
-**config["template_angle_embedder"],
-)
-self.template_pair_embedder = TemplatePairEmbedder(
-**config["template_pair_embedder"],
-)
-self.template_pair_stack = TemplatePairStack(
-**config["template_pair_stack"],
-)
-self.template_pointwise_att = TemplatePointwiseAttention(
-**config["template_pointwise_attention"],
-)
-def forward(
-self,
-batch,
-z,
-pair_mask,
-templ_dim,
-chunk_size,
-_mask_trans=True,
-):
-# Embed the templates one at a time (with a poor man's vmap)
-template_embeds = []
-n_templ = batch["template_aatype"].shape[templ_dim]
-for i in range(n_templ):
-idx = batch["template_aatype"].new_tensor(i)
-single_template_feats = tensor_tree_map(
-lambda t: torch.index_select(t, templ_dim, idx),
-batch,
-)
-single_template_embeds = {}
-if self.config.embed_angles:
-template_angle_feat = build_template_angle_feat(
-single_template_feats,
-)
-# [*, S_t, N, C_m]
-a = self.template_angle_embedder(template_angle_feat)
-single_template_embeds["angle"] = a
-# [*, S_t, N, N, C_t]
-t = build_template_pair_feat(
-single_template_feats,
-use_unit_vector=self.config.use_unit_vector,
-inf=self.config.inf,
-eps=self.config.eps,
-**self.config.distogram,
-).to(z.dtype)
-t = self.template_pair_embedder(t)
-single_template_embeds.update({"pair": t})
-template_embeds.append(single_template_embeds)
-template_embeds = dict_multimap(
-partial(torch.cat, dim=templ_dim),
-template_embeds,
-)
-# [*, S_t, N, N, C_z]
-t = self.template_pair_stack(
-template_embeds["pair"],
-pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-chunk_size=chunk_size,
-_mask_trans=_mask_trans,
-)
-# [*, N, N, C_z]
-t = self.template_pointwise_att(
-t,
-z,
-template_mask=batch["template_mask"].to(dtype=z.dtype),
-chunk_size=chunk_size,
-)
-t = t * (torch.sum(batch["template_mask"]) > 0)
-ret = {}
-if self.config.embed_angles:
-ret["template_pair_embedding"] = template_embeds["angle"]
-ret.update({"template_pair_embedding": t})
-return ret
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
@@ -625,12 +537,13 @@ class TemplateEmbedder(nn.Module):
**config["template_pointwise_attention"],
)
-def forward(self,
-batch,
+def forward(
+self,
+batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
@@ -706,7 +619,7 @@ class TemplatePairEmbedderMultimer(nn.Module):
c_dgram: int,
c_aatype: int,
):
-super().__init__()
+super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
@@ -765,7 +678,7 @@ class TemplateSingleEmbedderMultimer(nn.Module):
c_in: int,
c_m: int,
):
-super().__init__()
+super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
......
@@ -117,34 +117,19 @@ class MSATransition(nn.Module):
return m
-class EvoformerBlockCore(nn.Module):
+class PairStack(nn.Module):
def __init__(
self,
-c_m: int,
c_z: int,
-c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
-no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
inf: float,
-eps: float,
-_is_extra_msa_stack: bool = False,
+eps: float
):
-super(EvoformerBlockCore, self).__init__()
+super(PairStack, self).__init__()
-self.msa_transition = MSATransition(
-c_m=c_m,
-n=transition_n,
-)
-self.outer_product_mean = OuterProductMean(
-c_m,
-c_z,
-c_hidden_opm,
-)
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
@@ -178,25 +163,15 @@ class EvoformerBlockCore(nn.Module):
def forward(
self,
-m: torch.Tensor,
z: torch.Tensor,
-msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
-msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
-m = m + self.msa_transition(
-m, mask=msa_trans_mask, chunk_size=chunk_size
-)
-z = z + self.outer_product_mean(
-m, mask=msa_mask, chunk_size=chunk_size
-)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
@@ -209,7 +184,7 @@ class EvoformerBlockCore(nn.Module):
z, mask=pair_trans_mask, chunk_size=chunk_size
)
-return m, z
+return z
class EvoformerBlock(nn.Module):
@@ -225,11 +200,14 @@ class EvoformerBlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
+opm_first: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
+self.opm_first = opm_first
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
@@ -247,18 +225,26 @@ class EvoformerBlock(nn.Module):
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
-self.core = EvoformerBlockCore(
+self.msa_transition = MSATransition(
c_m=c_m,
+n=transition_n,
+)
+self.outer_product_mean = OuterProductMean(
+c_m,
+c_z,
+c_hidden_opm,
+)
+self.pair_stack = PairStack(
c_z=c_z,
-c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
-no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
-eps=eps,
+eps=eps
)
def forward(self,
@@ -269,17 +255,34 @@ class EvoformerBlock(nn.Module):
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
+# DeepMind doesn't mask these transitions in the source, so _mask_trans
+# should be disabled to better approximate the exact activations of
+# the original.
+msa_trans_mask = msa_mask if _mask_trans else None
+if self.opm_first:
+z = z + self.outer_product_mean(
+m, mask=msa_mask, chunk_size=chunk_size
+)
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
-m, z = self.core(
-m,
-z,
-msa_mask=msa_mask,
-pair_mask=pair_mask,
-chunk_size=chunk_size,
-_mask_trans=_mask_trans,
+m = m + self.msa_transition(
+m, mask=msa_trans_mask, chunk_size=chunk_size
+)
+if not self.opm_first:
+z = z + self.outer_product_mean(
+m, mask=msa_mask, chunk_size=chunk_size
+)
+z = self.pair_stack(
+z,
+pair_mask=pair_mask,
+chunk_size=chunk_size,
)
return m, z
@@ -304,12 +307,14 @@ class ExtraMSABlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
+opm_first: bool,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
+self.opm_first = opm_first
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
@@ -330,13 +335,21 @@ class ExtraMSABlock(nn.Module):
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
-self.core = EvoformerBlockCore(
+self.msa_transition = MSATransition(
c_m=c_m,
+n=transition_n,
+)
+self.outer_product_mean = OuterProductMean(
+c_m,
+c_z,
+c_hidden_opm,
+)
+self.pair_stack = PairStack(
c_z=c_z,
-c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
-no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
@@ -361,6 +374,11 @@ class ExtraMSABlock(nn.Module):
m1 += m2
return m1
+if self.opm_first:
+z = z + self.outer_product_mean(
+m, mask=msa_mask, chunk_size=chunk_size
+)
m = add(m, self.msa_dropout_layer(
self.msa_att_row(
@@ -377,8 +395,17 @@ class ExtraMSABlock(nn.Module):
def fn(m, z):
m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
-m, z = self.core(
-m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+m = add(m, self.msa_transition(
+m, mask=msa_mask, chunk_size=chunk_size
+))
+if not self.opm_first:
+z = z + self.outer_product_mean(
+m, mask=msa_mask, chunk_size=chunk_size
+)
+z = self.pair_stack(
+z, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
@@ -414,6 +441,7 @@ class EvoformerStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
+opm_first: bool,
blocks_per_ckpt: int,
inf: float,
eps: float,
@@ -475,6 +503,7 @@ class EvoformerStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
+opm_first=opm_first,
inf=inf,
eps=eps,
)
@@ -555,6 +584,7 @@ class ExtraMSAStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
+opm_first: bool,
inf: float,
eps: float,
ckpt: bool,
@@ -581,6 +611,7 @@ class ExtraMSAStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
+opm_first=opm_first,
inf=inf,
eps=eps,
ckpt=ckpt if chunk_msa_attn else False,
......
@@ -169,8 +169,8 @@ class PointProjection(nn.Module):
def forward(self,
activations: torch.Tensor,
-rigids: Rigid3Array,
-) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
+rigids: Union[Rigid, Rigid3Array],
+) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
points_local = points_local.reshape(
@@ -181,8 +181,9 @@ class PointProjection(nn.Module):
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
-points_local = Vec3Array(*points_local)
-points_global = rigids[..., None, None].apply_to_point(points_local)
+points_local = torch.stack(points_local, dim=-1)
+points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points):
return points_global, points_local
@@ -285,7 +286,7 @@ class InvariantPointAttention(nn.Module):
self,
s: torch.Tensor,
z: torch.Tensor,
-r: Rigid,
+r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
) -> torch.Tensor:
"""
@@ -340,9 +341,6 @@ class InvariantPointAttention(nn.Module):
k, v = torch.split(kv, self.c_hidden, dim=-1)
kv_pts = self.linear_kv_points(s, r)
-# [*, N_res, H, (P_q + P_v), 3]
-kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
@@ -364,10 +362,16 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
-pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
-# [*, N_res, N_res, H, P_q]
-pt_att = sum([c**2 for c in pt_att])
+if self.is_multimer:
+pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
+# [*, N_res, N_res, H, P_q]
+pt_att = sum([c ** 2 for c in pt_att])
+else:
+pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+pt_att = pt_att ** 2
+pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
@@ -399,20 +403,42 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
+if self.is_multimer:
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, N_res, H, P_v]
-o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
+o_pt = v_pts[..., None, :, :, :] * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
+else:
+o_pt = torch.sum(
+(
+a[..., None, :, :, None]
+* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+),
+dim=-2,
+)
+# [*, N_res, H, P_v, 3]
+o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+o_pt = r[..., None, None].invert_apply(o_pt)
+# [*, N_res, H * P_v]
+o_pt_norm = flatten_final_dims(
+torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
+)
+# [*, N_res, H * P_v, 3]
+o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+o_pt = torch.unbind(o_pt, dim=-1)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
@@ -617,7 +643,10 @@ class StructureModule(nn.Module):
self.dropout_rate,
)
-self.bb_update = QuatRigid(self.c_s, full_quat=False)
+if self.is_multimer:
+self.bb_update = QuatRigid(self.c_s, full_quat=False)
+else:
+self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
self.c_s,
......
@@ -141,6 +141,7 @@ class TemplatePairStackBlock(nn.Module):
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
+tri_mul_first: bool,
inf: float,
**kwargs,
):
@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
+self.tri_mul_first = tri_mul_first
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -184,6 +186,38 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n,
)
+def tri_att_start_end(self, single, single_mask, chunk_size):
+single = single + self.dropout_row(
+self.tri_att_start(
+single,
+chunk_size=chunk_size,
+mask=single_mask
+)
+)
+single = single + self.dropout_col(
+self.tri_att_end(
+single,
+chunk_size=chunk_size,
+mask=single_mask
+)
+)
+return single
+def tri_mul_out_in(self, single, single_mask):
+single = single + self.dropout_row(
+self.tri_mul_out(
+single,
+mask=single_mask
+)
+)
+single = single + self.dropout_row(
+self.tri_mul_in(
+single,
+mask=single_mask
+)
+)
+return single
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
@@ -200,32 +234,17 @@ class TemplatePairStackBlock(nn.Module):
single = single_templates[i]
single_mask = single_templates_masks[i]
-single = single + self.dropout_row(
-self.tri_att_start(
-single,
-chunk_size=chunk_size,
-mask=single_mask
-)
-)
-single = single + self.dropout_col(
-self.tri_att_end(
-single,
-chunk_size=chunk_size,
-mask=single_mask
-)
-)
-single = single + self.dropout_row(
-self.tri_mul_out(
-single,
-mask=single_mask
-)
-)
-single = single + self.dropout_row(
-self.tri_mul_in(
-single,
-mask=single_mask
-)
-)
+if self.tri_mul_first:
+single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
+single_mask=single_mask),
+single_mask=single_mask,
+chunk_size=chunk_size)
+else:
+single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single,
+single_mask=single_mask,
+chunk_size=chunk_size),
+single_mask=single_mask)
single = single + self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
@@ -252,6 +271,7 @@ class TemplatePairStack(nn.Module):
no_heads,
pair_transition_n,
dropout_rate,
+tri_mul_first,
blocks_per_ckpt,
inf=1e9,
**kwargs,
@@ -287,6 +307,7 @@ class TemplatePairStack(nn.Module):
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
+tri_mul_first=tri_mul_first,
inf=inf,
)
self.blocks.append(block)
......
@@ -18,7 +18,7 @@ import math
import numpy as np
import torch
import torch.nn as nn
-from typing import Dict
+from typing import Dict, Union
from openfold.np import protein
import openfold.np.residue_constants as rc
@@ -179,11 +179,11 @@ def build_extra_msa_feat(batch):
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
-return msa_feat
+return torch.cat(msa_feat, dim=-1)
def torsion_angles_to_frames(
-r: Rigid,
+r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
@@ -220,8 +220,14 @@ def torsion_angles_to_frames(
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
-all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
-all_frames = default_r.compose_rotation(all_rots)
+if isinstance(r, Rigid):
+rigid_type = Rigid
+all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+all_frames = default_r.compose(all_rots)
+else:
+rigid_type = rigid_matrix_vector.Rigid3Array
+all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
+all_frames = default_r.compose_rotation(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
@@ -232,7 +238,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
-all_frames_to_bb = rigid_matrix_vector.Rigid3Array.cat(
+all_frames_to_bb = rigid_type.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
@@ -248,7 +254,7 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
-r: rigid_matrix_vector.Rigid3Array,
+r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
aatype: torch.Tensor,
default_frames,
group_idx,
@@ -277,6 +283,8 @@ def frames_and_literature_positions_to_atom14_pos(
# [*, N, 14]
atom_mask = atom_mask[aatype, ...]
+if isinstance(r, Rigid):
+atom_mask = atom_mask.unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
......
@@ -142,7 +142,7 @@ class Rigid3Array:
def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape)
-return Rigid3Aray(rots, trans)
+return Rigid3Array(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
return Rigid3Array(
@@ -174,3 +174,6 @@ class Rigid3Array:
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation)
+def cuda(self) -> Rigid3Array:
+return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
@@ -15,6 +15,7 @@
from __future__ import annotations
import dataclasses
+from typing import List
import torch
import numpy as np
@@ -172,10 +173,10 @@ class Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
-w *= inv_norm
-x *= inv_norm
-y *= inv_norm
-z *= inv_norm
+w = w * inv_norm
+x = x * inv_norm
+y = y * inv_norm
+z = z * inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
......
@@ -110,7 +110,7 @@ class Vec3Array:
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
-norm2 = torch.clamp(norm2, max=epsilon**2)
+norm2 = torch.clamp(norm2, min=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
......
@@ -129,7 +129,7 @@ def assign(translation_dict, orig_weights):
raise
-def get_translation_dict(model, is_multimer=False):
+def get_translation_dict(model, version, is_multimer=False):
#######################
# Some templates
#######################
@@ -247,7 +247,7 @@ def get_translation_dict(model, is_multimer=False):
)
IPAParams = lambda ipa: {
-"q_scalar_projection": LinearParams(ipa.linear_q),
+"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points.linear),
@@ -331,19 +331,19 @@ def get_translation_dict(model, is_multimer=False):
b.msa_att_row
),
col_att_name: msa_col_att_params,
-"msa_transition": MSATransitionParams(b.core.msa_transition),
+"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean":
-OuterProductMeanParams(b.core.outer_product_mean),
+OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing":
-TriMulOutParams(b.core.tri_mul_out),
+TriMulOutParams(b.pair_stack.tri_mul_out),
"triangle_multiplication_incoming":
-TriMulInParams(b.core.tri_mul_in),
+TriMulInParams(b.pair_stack.tri_mul_in),
"triangle_attention_starting_node":
-TriAttParams(b.core.tri_att_start),
+TriAttParams(b.pair_stack.tri_att_start),
"triangle_attention_ending_node":
-TriAttParams(b.core.tri_att_end),
+TriAttParams(b.pair_stack.tri_att_end),
"pair_transition":
-PairTransitionParams(b.core.pair_transition),
+PairTransitionParams(b.pair_stack.pair_transition),
}
return d
@@ -584,17 +584,6 @@ def get_translation_dict(model, is_multimer=False):
},
}
-return translations
-def import_jax_weights_(model, npz_path, version="model_1"):
-data = np.load(npz_path)
-translations = get_translation_dict(
-model,
-is_multimer=("multimer" in version)
-)
no_templ = [
"model_3",
"model_4",
@@ -615,6 +604,18 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.tm.linear)
}
+return translations
+def import_jax_weights_(model, npz_path, version="model_1"):
+data = np.load(npz_path)
+translations = get_translation_dict(
+model,
+version,
+is_multimer=("multimer" in version)
+)
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
......
@@ -636,9 +636,7 @@ def compute_tm(
)
bin_centers = _calculate_bin_centers(boundaries)
-soft_n = torch.sum(residue_weights, dim=-1).to(torch.int32)
-other = n.new_zeros() + 19
-clipped_n = torch.max(soft_n, other, dim=-1)
+clipped_n = max(torch.sum(residue_weights), 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
......
@@ -986,6 +986,16 @@ class Rigid:
"""
return self._trans.device
+@property
+def dtype(self) -> torch.dtype:
+"""
+Returns the dtype of the Rigid tensors.
+Returns:
+The dtype of the Rigid tensors
+"""
+return self._rots.dtype
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
......
@@ -46,26 +46,26 @@ def import_alphafold():
def get_alphafold_config():
-config = alphafold.model.config.model_config("model_1_ptm") # noqa
+config = alphafold.model.config.model_config(consts.model) # noqa
config.model.global_config.deterministic = True
return config
-_param_path = "openfold/resources/params/params_model_1_ptm.npz"
+_param_path = f"openfold/resources/params/params_{consts.model}.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if _model is None:
-_model = AlphaFold(model_config("model_1_ptm"))
+_model = AlphaFold(model_config(consts.model))
_model = _model.eval()
if not os.path.exists(_param_path):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
-import_jax_weights_(_model, _param_path, version="model_1_ptm")
+import_jax_weights_(_model, _param_path, version=consts.model)
_model = _model.cuda()
return _model
......
@@ -2,6 +2,9 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
+"model": "model_1_multimer_v2", # monomer: model_1_ptm, multimer: model_1_multimer_v2
+"is_multimer": True, # monomer: False, multimer: True
+"chunk_size": 4,
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
@@ -15,6 +18,7 @@ consts = mlc.ConfigDict(
"c_s": 384,
"c_t": 64,
"c_e": 64,
+"msa_logits": 22 # monomer: 23, multimer: 22
}
)
......
@@ -12,9 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from random import randint
import numpy as np
from scipy.spatial.transform import Rotation
+from tests.config import consts
+def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
+n_chain = randint(1, n_res // min_chain_len) if consts.is_multimer else 1
+if not split_chains:
+return [0] * n_res
+assert n_res >= n_chain
+pieces = []
+asym_ids = []
+for idx in range(n_chain - 1):
+piece = randint(min_chain_len, (n_res - sum(pieces) - n_chain + idx - min_chain_len))
+pieces.append(piece)
+asym_ids.extend(piece * [idx])
+asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
+return np.array(asym_ids).astype(np.int64)
def random_template_feats(n_templ, n, batch_size=None):
b = []
@@ -39,6 +61,11 @@ def random_template_feats(n_templ, n, batch_size=None):
}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
+if consts.is_multimer:
+asym_ids = np.array(random_asym_ids(n))
+batch["asym_id"] = np.tile(asym_ids[np.newaxis, :], (*b, n_templ, 1))
return batch
......
@@ -15,19 +15,13 @@
import pickle
import shutil
-import torch
import numpy as np
import unittest
from openfold.data.data_pipeline import DataPipeline
-from openfold.data.templates import TemplateHitFeaturizer
+from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
-from openfold.model.embedders import (
-InputEmbedder,
-RecyclingEmbedder,
-TemplateAngleEmbedder,
-TemplatePairEmbedder,
-)
import tests.compare_utils as compare_utils
+from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
alphafold_feature_dict = pickle.load(fp)
-template_featurizer = TemplateHitFeaturizer(
-mmcif_dir="tests/test_data/mmcifs",
-max_template_date="2021-12-20",
-max_hits=20,
-kalign_binary_path=shutil.which("kalign"),
-_zero_center_positions=False,
-)
+if consts.is_multimer:
+# template_featurizer = HmmsearchHitFeaturizer(
+# mmcif_dir="tests/test_data/mmcifs",
+# max_template_date="2021-12-20",
+# max_hits=20,
+# kalign_binary_path=shutil.which("kalign"),
+# _zero_center_positions=False,
+# )
+template_featurizer = HhsearchHitFeaturizer(
+mmcif_dir="tests/test_data/mmcifs",
+max_template_date="2021-12-20",
+max_hits=20,
+kalign_binary_path=shutil.which("kalign"),
+_zero_center_positions=False,
+)
+else:
+template_featurizer = HhsearchHitFeaturizer(
+mmcif_dir="tests/test_data/mmcifs",
+max_template_date="2021-12-20",
+max_hits=20,
+kalign_binary_path=shutil.which("kalign"),
+_zero_center_positions=False,
+)
data_pipeline = DataPipeline(
template_featurizer=template_featurizer,
......
import copy
import gzip
import os
import pickle
import numpy as np
@@ -178,7 +174,7 @@ class TestDataTransforms(unittest.TestCase):
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
-protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
+protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
......