Unverified Commit 984370ce authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #178 from epenning/fix_offload_flash

Fix propagation of use_flash for offloaded inference
parents 965961af 54457695
......@@ -721,6 +721,7 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
......@@ -731,6 +732,7 @@ class EvoformerStack(nn.Module):
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment