Unverified commit 984370ce, authored by Gustaf Ahdritz and committed by GitHub

Merge pull request #178 from epenning/fix_offload_flash

Fix propagation of use_flash for offloaded inference
Parents: 965961af, 54457695
@@ -721,6 +721,7 @@ class EvoformerStack(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: int,
         use_lma: bool = False,
+        use_flash: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert(not (self.training or torch.is_grad_enabled()))
@@ -731,6 +732,7 @@ class EvoformerStack(nn.Module):
             z=input_tensors[1],
             chunk_size=chunk_size,
             use_lma=use_lma,
+            use_flash=use_flash,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             inplace_safe=True,
...
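For context, here is a minimal self-contained sketch of the bug class this merge fixes. The function names below are hypothetical stand-ins, not OpenFold's actual code: the point is that a wrapper on the offloaded inference path which does not accept and forward use_flash silently falls back to the default attention kernel, even when the caller requested FlashAttention.

    # Hypothetical stand-ins for illustration only; not OpenFold's real API.
    def attention(x, use_flash: bool = False) -> str:
        # Stand-in for kernel selection inside the attention module.
        return "flash" if use_flash else "default"

    def forward_offload_buggy(x, use_lma: bool = False) -> str:
        # BUG: use_flash is neither accepted nor forwarded, so the
        # offloaded path always runs the default kernel.
        return attention(x)

    def forward_offload_fixed(
        x,
        use_lma: bool = False,
        use_flash: bool = False,
    ) -> str:
        # FIX (mirroring the diff above): accept the flag and pass it through.
        return attention(x, use_flash=use_flash)

    assert forward_offload_buggy(None, use_lma=False) == "default"      # flag lost
    assert forward_offload_fixed(None, use_flash=True) == "flash"       # flag propagated

The actual fix in the diff is the same two-line pattern applied to the EvoformerStack offloaded-inference path: add use_flash to the method signature, then forward it to the inner call alongside use_lma.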