Commit 54457695 authored by epenning's avatar epenning
Browse files

Propagate use_flash in offloaded inference

parent 0d2dd5d0
@@ -721,6 +721,7 @@ class EvoformerStack(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: int,
         use_lma: bool = False,
+        use_flash: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert(not (self.training or torch.is_grad_enabled()))
@@ -731,6 +732,7 @@ class EvoformerStack(nn.Module):
             z=input_tensors[1],
             chunk_size=chunk_size,
             use_lma=use_lma,
+            use_flash=use_flash,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             inplace_safe=True,
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment