Commit bebbe782 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Speed up inference

parent 6e66b218
...@@ -177,10 +177,9 @@ class AlphaFold(nn.Module): ...@@ -177,10 +177,9 @@ class AlphaFold(nn.Module):
t, t,
z, z,
template_mask=batch["template_mask"].to(dtype=z.dtype), template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
) )
if(inplace_safe): if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0) t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else: else:
...@@ -199,6 +198,8 @@ class AlphaFold(nn.Module): ...@@ -199,6 +198,8 @@ class AlphaFold(nn.Module):
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
del t
return ret return ret
def iteration(self, feats, prevs, _recycle=True): def iteration(self, feats, prevs, _recycle=True):
......
...@@ -102,7 +102,8 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -102,7 +102,8 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor, t: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None, template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, # This module suffers greatly from a small chunk size
chunk_size: Optional[int] = 256,
use_lma: bool = False, use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -209,13 +210,14 @@ class TemplatePairStackBlock(nn.Module): ...@@ -209,13 +210,14 @@ class TemplatePairStackBlock(nn.Module):
for i in range(len(single_templates)): for i in range(len(single_templates)):
single = single_templates[i] single = single_templates[i]
single_mask = single_templates_masks[i] single_mask = single_templates_masks[i]
single = add(single, single = add(single,
self.dropout_row( self.dropout_row(
self.tri_att_start( self.tri_att_start(
single, single,
chunk_size=chunk_size, chunk_size=chunk_size,
mask=single_mask, mask=single_mask,
use_memory_efficient_kernel=not use_lma,
use_lma=use_lma, use_lma=use_lma,
) )
), ),
...@@ -228,6 +230,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -228,6 +230,7 @@ class TemplatePairStackBlock(nn.Module):
single, single,
chunk_size=chunk_size, chunk_size=chunk_size,
mask=single_mask, mask=single_mask,
use_memory_efficient_kernel=not use_lma,
use_lma=use_lma, use_lma=use_lma,
) )
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment