Commit 96ca6460 authored by Gustaf Ahdritz
Browse files

Scale back chunk size tuning a little

parent f8f74006
......@@ -332,6 +332,9 @@ class EvoformerBlock(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
inplace_safe = not (self.training or torch.is_grad_enabled())
print(chunk_size)
print(_attn_chunk_size)
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
......@@ -653,7 +656,7 @@ class EvoformerStack(nn.Module):
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
......@@ -783,7 +786,7 @@ class ExtraMSAStack(nn.Module):
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
......
......@@ -201,7 +201,11 @@ class TemplatePairStackBlock(nn.Module):
use_lma: bool = False,
_mask_trans: bool = True,
_inplace: bool = False,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
......@@ -216,7 +220,7 @@ class TemplatePairStackBlock(nn.Module):
self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
)
......@@ -228,7 +232,7 @@ class TemplatePairStackBlock(nn.Module):
self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
)
......@@ -375,12 +379,17 @@ class TemplatePairStack(nn.Module):
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t,),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
blocks = [
partial(b,
chunk_size=chunk_size,
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
t, = checkpoint_blocks(
blocks=blocks,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment