Commit 29864369 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Squeeze templates

parent 913903e0
......@@ -138,7 +138,7 @@ class AlphaFold(nn.Module):
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
......
......@@ -448,7 +448,7 @@ def embed_templates_offload(
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment