Commit 96ca6460 authored by Gustaf Ahdritz

Scale back chunk size tuning a little

parent f8f74006
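In short, the diff below scales back how aggressively the auto-tuned chunk size is applied to the attention sub-modules: attention now gets at most a quarter of the tuned value (previously a half), floored at the caller-supplied chunk_size, and TemplatePairStackBlock gains the same _attn_chunk_size plumbing the Evoformer blocks already have. A minimal sketch of the capping expression follows; `attn_chunk_size` is a hypothetical helper used only for illustration, not a function that exists in OpenFold.

# Sketch only: `attn_chunk_size` is a made-up name illustrating the expression
# used in this commit, not part of OpenFold's API.
def attn_chunk_size(chunk_size: int, tuned_chunk_size: int) -> int:
    # Cap the attention chunk size at a quarter of the tuned chunk size
    # (previously a half), but never drop below the requested chunk_size.
    return max(chunk_size, tuned_chunk_size // 4)

# Example: a requested chunk size of 4 with a tuned chunk size of 512
# now yields 128 for the attention ops instead of the previous 256.
assert attn_chunk_size(4, 512) == 128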
@@ -332,6 +332,9 @@ class EvoformerBlock(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         inplace_safe = not (self.training or torch.is_grad_enabled())
 
+        print(chunk_size)
+        print(_attn_chunk_size)
+
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size
@@ -653,7 +656,7 @@ class EvoformerStack(nn.Module):
                     chunk_size=tuned_chunk_size,
                     # A temporary measure to address torch's occasional
                     # inability to allocate large tensors
-                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
                 ) for b in blocks
             ]
@@ -783,7 +786,7 @@ class ExtraMSAStack(nn.Module):
                     chunk_size=tuned_chunk_size,
                     # A temporary measure to address torch's occasional
                     # inability to allocate large tensors
-                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
                 ) for b in blocks
            ]
@@ -201,7 +201,11 @@ class TemplatePairStackBlock(nn.Module):
         use_lma: bool = False,
         _mask_trans: bool = True,
         _inplace: bool = False,
+        _attn_chunk_size: Optional[int] = None,
     ):
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
         single_templates = [
             t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
         ]
@@ -216,7 +220,7 @@ class TemplatePairStackBlock(nn.Module):
                 self.dropout_row(
                     self.tri_att_start(
                         single,
-                        chunk_size=chunk_size,
+                        chunk_size=_attn_chunk_size,
                         mask=single_mask,
                         use_lma=use_lma,
                     )
@@ -228,7 +232,7 @@ class TemplatePairStackBlock(nn.Module):
                 self.dropout_col(
                     self.tri_att_end(
                         single,
-                        chunk_size=chunk_size,
+                        chunk_size=_attn_chunk_size,
                         mask=single_mask,
                         use_lma=use_lma,
                     )
@@ -375,12 +379,17 @@ class TemplatePairStack(nn.Module):
         ]
         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(t,),
                 min_chunk_size=chunk_size,
             )
-            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+            blocks = [
+                partial(b,
+                    chunk_size=chunk_size,
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
+                ) for b in blocks
+            ]
 
         t, = checkpoint_blocks(
             blocks=blocks,
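The template-stack half of the diff mirrors the Evoformer change: TemplatePairStackBlock gains an optional _attn_chunk_size that falls back to chunk_size, and TemplatePairStack binds both values onto its blocks with functools.partial once the chunk-size tuner has run. The sketch below restates that wiring under simplified, assumed signatures; the real blocks take several more arguments and run triangle attention internally.

from functools import partial
from typing import Callable, List, Optional

def template_block(t, chunk_size: int, _attn_chunk_size: Optional[int] = None):
    # Fall back to the generic chunk size when no attention-specific value is
    # supplied, mirroring the default added to TemplatePairStackBlock.forward.
    if _attn_chunk_size is None:
        _attn_chunk_size = chunk_size
    # ... triangle attention would run here with chunk_size=_attn_chunk_size ...
    return t

def bind_blocks(blocks: List[Callable], chunk_size: int, tuned_chunk_size: int) -> List[Callable]:
    # Bind both chunk sizes up front, as TemplatePairStack.forward now does
    # after running its chunk-size tuner.
    return [
        partial(
            b,
            chunk_size=chunk_size,
            _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
        )
        for b in blocks
    ]

# Example: each bound block receives the scaled-back attention chunk size.
bound = bind_blocks([template_block], chunk_size=4, tuned_chunk_size=512)
out = bound[0](t=None)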