Commit bebbe782 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Speed up inference

parent 6e66b218
......@@ -177,7 +177,6 @@ class AlphaFold(nn.Module):
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
)
......@@ -199,6 +198,8 @@ class AlphaFold(nn.Module):
ret.update({"template_pair_embedding": t})
del t
return ret
def iteration(self, feats, prevs, _recycle=True):
......
......@@ -102,7 +102,8 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
# This module suffers greatly from a small chunk size
chunk_size: Optional[int] = 256,
use_lma: bool = False,
) -> torch.Tensor:
"""
......@@ -216,6 +217,7 @@ class TemplatePairStackBlock(nn.Module):
single,
chunk_size=chunk_size,
mask=single_mask,
use_memory_efficient_kernel=not use_lma,
use_lma=use_lma,
)
),
......@@ -228,6 +230,7 @@ class TemplatePairStackBlock(nn.Module):
single,
chunk_size=chunk_size,
mask=single_mask,
use_memory_efficient_kernel=not use_lma,
use_lma=use_lma,
)
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment