Commit aea444ed authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move template pair stack out of single-template block (for DeepSpeed)

parent 896f8935
......@@ -136,12 +136,6 @@ class AlphaFold(nn.Module):
**self.config.template.distogram,
)
t = self.template_pair_embedder(t)
t = self.template_pair_stack(
t,
pair_mask.unsqueeze(-3),
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
single_template_embeds.update({"pair": t})
......@@ -153,8 +147,14 @@ class AlphaFold(nn.Module):
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3),
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"],
chunk_size=chunk_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment