Commit 54457695 authored by epenning's avatar epenning
Browse files

Propagate use_flash in offloaded inference

parent 0d2dd5d0
......@@ -721,6 +721,7 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
......@@ -731,6 +732,7 @@ class EvoformerStack(nn.Module):
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment