Commit 12caaa89 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix long sequence inference bug

parent d54b4afa
...@@ -465,7 +465,7 @@ def embed_templates_offload( ...@@ -465,7 +465,7 @@ def embed_templates_offload(
# [*, 1, N, N, C_z] # [*, 1, N, N, C_z]
t = model.template_pair_stack( t = model.template_pair_stack(
t, t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size, chunk_size=model.globals.chunk_size,
use_lma=model.globals.use_lma, use_lma=model.globals.use_lma,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment