Commit 53bb9c10 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix mask casting

parent a8601529
...@@ -149,7 +149,7 @@ class AlphaFold(nn.Module): ...@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = self.template_pair_stack( t = self.template_pair_stack(
template_embeds["pair"], template_embeds["pair"],
pair_mask.unsqueeze(-3), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -158,7 +158,7 @@ class AlphaFold(nn.Module): ...@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
t = self.template_pointwise_att( t = self.template_pointwise_att(
t, t,
z, z,
template_mask=batch["template_mask"], template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
) )
t = t * (torch.sum(batch["template_mask"]) > 0) t = t * (torch.sum(batch["template_mask"]) > 0)
...@@ -246,34 +246,32 @@ class AlphaFold(nn.Module): ...@@ -246,34 +246,32 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled: if self.config.template.enabled:
template_mask = feats["template_mask"] template_feats = {
if(torch.any(template_mask)): k: v for k, v in feats.items() if k.startswith("template_")
template_feats = { }
k: v for k, v in feats.items() if k.startswith("template_") template_embeds = self.embed_templates(
} template_feats,
template_embeds = self.embed_templates( z,
template_feats, pair_mask.to(dtype=z.dtype),
z, no_batch_dims,
pair_mask, )
no_batch_dims,
)
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles: if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
[m, template_embeds["template_angle_embedding"]], [m, template_embeds["template_angle_embedding"]],
dim=-3 dim=-3
) )
# [*, S, N] # [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"] torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat( msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], [feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2 dim=-2
) )
# Embed extra MSA features + merge with pairwise embeddings # Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled: if self.config.extra_msa.enabled:
...@@ -284,9 +282,9 @@ class AlphaFold(nn.Module): ...@@ -284,9 +282,9 @@ class AlphaFold(nn.Module):
z = self.extra_msa_stack( z = self.extra_msa_stack(
a, a,
z, z,
msa_mask=feats["extra_msa_mask"], msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
pair_mask=pair_mask, pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -297,8 +295,8 @@ class AlphaFold(nn.Module): ...@@ -297,8 +295,8 @@ class AlphaFold(nn.Module):
m, z, s = self.evoformer( m, z, s = self.evoformer(
m, m,
z, z,
msa_mask=msa_mask, msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask, pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -312,7 +310,7 @@ class AlphaFold(nn.Module): ...@@ -312,7 +310,7 @@ class AlphaFold(nn.Module):
s, s,
z, z,
feats["aatype"], feats["aatype"],
mask=feats["seq_mask"], mask=feats["seq_mask"].to(dtype=s.dtype),
) )
outputs["final_atom_positions"] = atom14_to_atom37( outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats outputs["sm"]["positions"][-1], feats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment