Commit aea444ed authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move template pair stack out of single-template block (for DeepSpeed)

parent 896f8935
......@@ -106,7 +106,7 @@ class AlphaFold(nn.Module):
self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
......@@ -136,12 +136,6 @@ class AlphaFold(nn.Module):
**self.config.template.distogram,
)
t = self.template_pair_embedder(t)
t = self.template_pair_stack(
t,
pair_mask.unsqueeze(-3),
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
single_template_embeds.update({"pair": t})
......@@ -153,8 +147,14 @@ class AlphaFold(nn.Module):
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3),
chunk_size=chunk_size,
_mask_trans=self.config._mask_trans,
)
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"],
chunk_size=chunk_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment