"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "df299e7c2c215fa11a1dba08cc8f5904a370755b"
Commit 53bb9c10 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix mask casting

parent a8601529
......@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3),
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
......@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"],
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
......@@ -246,34 +246,32 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_mask = feats["template_mask"]
if(torch.any(template_mask)):
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
)
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
......@@ -284,9 +282,9 @@ class AlphaFold(nn.Module):
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
)
......@@ -297,8 +295,8 @@ class AlphaFold(nn.Module):
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
......@@ -312,7 +310,7 @@ class AlphaFold(nn.Module):
s,
z,
feats["aatype"],
mask=feats["seq_mask"],
mask=feats["seq_mask"].to(dtype=s.dtype),
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment