Commit bebbe782 authored by Gustaf Ahdritz

Speed up inference

parent 6e66b218
@@ -177,10 +177,9 @@ class AlphaFold(nn.Module):
             t,
             z,
             template_mask=batch["template_mask"].to(dtype=z.dtype),
-            chunk_size=self.globals.chunk_size,
             use_lma=self.globals.use_lma,
         )
         if(inplace_safe):
             t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
         else:
@@ -199,6 +198,8 @@ class AlphaFold(nn.Module):
         ret.update({"template_pair_embedding": t})
         del t
         return ret

     def iteration(self, feats, prevs, _recycle=True):
......
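For context on the call site above: TemplatePointwiseAttention lets each pair position (i, j) of the pair representation z attend over the template dimension of t, and the `t *= (torch.sum(batch["template_mask"], dim=-1) > 0)` line zeroes the update when no templates are present. Below is a minimal single-head sketch of that computation, not OpenFold's implementation; the class name, dimensions, and bare linear projections are all assumptions for illustration.

```python
import torch
import torch.nn as nn

class PointwiseTemplateAttn(nn.Module):
    # Toy stand-in for TemplatePointwiseAttention: every pair position
    # (i, j) of the pair representation z attends over the template
    # dimension of t. Single-head and ungated; dims are assumptions.
    def __init__(self, c_t: int, c_z: int):
        super().__init__()
        self.q = nn.Linear(c_z, c_t, bias=False)
        self.k = nn.Linear(c_t, c_t, bias=False)
        self.v = nn.Linear(c_t, c_t, bias=False)

    def forward(self, t, z, template_mask=None):
        # t: [n_templ, n_res, n_res, c_t], z: [n_res, n_res, c_z]
        q = self.q(z).unsqueeze(-2)                 # [n_res, n_res, 1, c_t]
        k = self.k(t).permute(1, 2, 0, 3)           # [n_res, n_res, n_templ, c_t]
        v = self.v(t).permute(1, 2, 0, 3)
        a = (q * k).sum(-1) / (k.shape[-1] ** 0.5)  # [n_res, n_res, n_templ]
        if template_mask is not None:               # template_mask: [n_templ]
            a = a.masked_fill(template_mask[None, None] == 0, -1e9)
        a = a.softmax(-1)
        return (a.unsqueeze(-1) * v).sum(-2)        # [n_res, n_res, c_t]

attn = PointwiseTemplateAttn(c_t=64, c_z=128)
t = torch.randn(4, 16, 16, 64)                      # 4 templates, 16 residues
z = torch.randn(16, 16, 128)
out = attn(t, z, template_mask=torch.tensor([1.0, 1.0, 1.0, 0.0]))
print(out.shape)                                    # torch.Size([16, 16, 64])
```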
@@ -102,7 +102,8 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         z: torch.Tensor,
         template_mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
+        # This module suffers greatly from a small chunk size
+        chunk_size: Optional[int] = 256,
         use_lma: bool = False,
     ) -> torch.Tensor:
         """
@@ -209,13 +210,14 @@ class TemplatePairStackBlock(nn.Module):
         for i in range(len(single_templates)):
             single = single_templates[i]
             single_mask = single_templates_masks[i]
             single = add(single,
                 self.dropout_row(
                     self.tri_att_start(
                         single,
                         chunk_size=chunk_size,
                         mask=single_mask,
+                        use_memory_efficient_kernel=not use_lma,
                         use_lma=use_lma,
                     )
                 ),
@@ -228,6 +230,7 @@ class TemplatePairStackBlock(nn.Module):
                         single,
                         chunk_size=chunk_size,
                         mask=single_mask,
+                        use_memory_efficient_kernel=not use_lma,
                         use_lma=use_lma,
                     )
                 ),
......
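The two keyword arguments added here select between mutually exclusive attention backends: the fused memory-efficient kernel is enabled exactly when low-memory attention is not requested (`use_memory_efficient_kernel=not use_lma`). A rough sketch of what such a dispatch can look like; this is not OpenFold's attention code, the fused-kernel branch is a placeholder, and the LMA branch is simplified to query-only chunking (the real low-memory path also chunks keys/values with a streaming softmax):

```python
import torch

def _attn(q, k, v, bias):
    # Reference softmax attention; bias is an additive mask (0 / -inf style)
    # of shape [..., n_q, n_k].
    a = torch.einsum("...qc,...kc->...qk", q, k) / (q.shape[-1] ** 0.5)
    return torch.einsum("...qk,...kc->...qc", (a + bias).softmax(-1), v)

def attention(q, k, v, bias, use_memory_efficient_kernel=False, use_lma=False):
    # The two fast paths are alternatives, never combined.
    assert not (use_memory_efficient_kernel and use_lma)
    if use_lma:
        # Simplified low-memory path: chunk the query dimension, which is
        # exact because each query row's softmax is independent. 1024 is an
        # arbitrary illustrative chunk size.
        out = [
            _attn(q[..., i : i + 1024, :], k, v, bias[..., i : i + 1024, :])
            for i in range(0, q.shape[-2], 1024)
        ]
        return torch.cat(out, dim=-2)
    if use_memory_efficient_kernel:
        # Placeholder: a real implementation would dispatch to a fused
        # CUDA/Triton attention kernel here; we fall back to the reference.
        return _attn(q, k, v, bias)
    return _attn(q, k, v, bias)
```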