Commit d6b36a80 authored by Gustaf Ahdritz

Tweak chunk tuning a little

parent 29f2ffe0
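
In short, reading from the diff below: the Evoformer and extra-MSA blocks gain a private _attn_chunk_size argument that falls back to chunk_size when unset, and the stacks cap the attention chunk size at half of whatever the tuner settles on. A minimal sketch of that policy; the numbers are illustrative, not taken from the source:

from typing import Optional

def resolve_attn_chunk_size(chunk_size: Optional[int],
                            _attn_chunk_size: Optional[int]) -> Optional[int]:
    # Mirrors the defaulting added to the blocks below: attention falls back
    # to the ordinary chunk size when no separate value is supplied.
    return chunk_size if _attn_chunk_size is None else _attn_chunk_size

# Stack-level policy from the EvoformerStack/ExtraMSAStack hunks: attention
# gets at most half the tuned chunk size, but never less than the
# user-supplied minimum.
chunk_size, tuned_chunk_size = 4, 512   # illustrative values
assert max(chunk_size, tuned_chunk_size // 2) == 256
assert resolve_attn_chunk_size(chunk_size, None) == 4
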
@@ -185,6 +185,7 @@ class EvoformerBlockCore(nn.Module):
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
@@ -192,6 +193,9 @@ class EvoformerBlockCore(nn.Module):
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None
 
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
         # Need to dodge activation checkpoints
         inplace_safe = not (self.training or torch.is_grad_enabled())
@@ -240,7 +244,7 @@ class EvoformerBlockCore(nn.Module):
                 self.tri_att_start(
                     z,
                     mask=pair_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma
                 )
             ),
@@ -251,7 +255,7 @@ class EvoformerBlockCore(nn.Module):
                 self.tri_att_end(
                     z,
                     mask=pair_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                 )
             ),
@@ -324,21 +328,33 @@ class EvoformerBlock(nn.Module):
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
+        m = add(m,
+            self.msa_dropout_layer(
                 self.msa_att_row(
                     m,
                     z=z,
                     mask=msa_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                 )
+            ),
+            inplace=inplace_safe,
         )
-        m = m + self.msa_att_col(
+        m = add(m,
+            self.msa_att_col(
                 m,
                 mask=msa_mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
+            ),
+            inplace=inplace_safe,
         )
         m, z = self.core(
             m,
@@ -348,6 +364,7 @@ class EvoformerBlock(nn.Module):
             chunk_size=chunk_size,
             use_lma=use_lma,
             _mask_trans=_mask_trans,
+            _attn_chunk_size=_attn_chunk_size,
         )
 
         return m, z
@@ -421,7 +438,11 @@ class ExtraMSABlock(nn.Module):
         use_lma: bool = False,
         _chunk_logits: Optional[int] = 1024,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
         # If function calls could speak...
         m = add(m,
             self.msa_dropout_layer(
@@ -429,7 +450,7 @@ class ExtraMSABlock(nn.Module):
                     m.clone() if torch.is_grad_enabled() else m,
                     z=z.clone() if torch.is_grad_enabled() else z,
                     mask=msa_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                     use_memory_efficient_kernel=not _chunk_logits and not use_lma,
                     _chunk_logits=
@@ -459,6 +480,7 @@ class ExtraMSABlock(nn.Module):
             chunk_size=chunk_size,
             use_lma=use_lma,
             _mask_trans=_mask_trans,
+            _attn_chunk_size=_attn_chunk_size
         )
 
         return m, z
@@ -621,12 +643,19 @@ class EvoformerStack(nn.Module):
             blocks = [partial(block_with_cache_clear, b) for b in blocks]
 
         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(m,z),
                 min_chunk_size=chunk_size,
             )
-            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                ) for b in blocks
+            ]
 
         blocks_per_ckpt = self.blocks_per_ckpt
         if(not torch.is_grad_enabled()):
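
To make the re-bound partials above concrete, here is a runnable toy with a stand-in block; nothing below is OpenFold API and the numbers are illustrative:

from functools import partial

def fake_block(m, z, *, chunk_size=None, _attn_chunk_size=None):
    # Stand-in for an EvoformerBlock forward pass; it only reports which
    # chunk sizes it was handed.
    print(f"chunk_size={chunk_size}, _attn_chunk_size={_attn_chunk_size}")
    return m, z

chunk_size = 4           # user-supplied minimum
tuned_chunk_size = 512   # value the tuner might settle on

block = partial(
    fake_block,
    chunk_size=tuned_chunk_size,                              # most submodules
    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),  # row/triangle attention
)
m, z = block("m", "z")   # prints: chunk_size=512, _attn_chunk_size=256
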
@@ -744,12 +773,19 @@ class ExtraMSAStack(nn.Module):
             blocks = [partial(clear_cache, b) for b in blocks]
 
         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(m,z),
                 min_chunk_size=chunk_size,
             )
-            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                ) for b in blocks
+            ]
 
         for b in blocks:
             if(self.ckpt and torch.is_grad_enabled()):
@@ -41,6 +41,7 @@ from openfold.utils.feats import (
 from openfold.utils.tensor_utils import (
     add,
     chunk_layer,
+    ChunkSizeTuner,
     permute_final_dims,
     flatten_final_dims,
     tensor_tree_map,
@@ -293,6 +294,7 @@ class TemplatePairStack(nn.Module):
         pair_transition_n,
         dropout_rate,
         blocks_per_ckpt,
+        tune_chunk_size: bool = False,
         inf=1e9,
         **kwargs,
     ):
@@ -333,6 +335,11 @@ class TemplatePairStack(nn.Module):
         self.layer_norm = LayerNorm(c_t)
 
+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if(tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
     def forward(
         self,
         t: torch.tensor,
@@ -355,8 +362,7 @@ class TemplatePairStack(nn.Module):
             expand_idx[-3] = t.shape[-4]
             mask = mask.expand(*expand_idx)
 
-        t, = checkpoint_blocks(
-            blocks=[
+        blocks = [
             partial(
                 b,
                 mask=mask,
@@ -366,7 +372,18 @@ class TemplatePairStack(nn.Module):
                 _inplace=not (self.training or torch.is_grad_enabled()),
             )
             for b in self.blocks
-            ],
+        ]
+
+        if(chunk_size is not None and self.chunk_size_tuner is not None):
+            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                args=(t,),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+
+        t, = checkpoint_blocks(
+            blocks=blocks,
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
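
TemplatePairStack now follows the same opt-in pattern: tuning only runs when a chunk_size is passed to forward() and tune_chunk_size=True was given at construction. A small sketch of that gating with a stand-in tuner; apart from the keyword names visible in the diff, everything here is illustrative:

class FakeTuner:
    # Stand-in for ChunkSizeTuner; a real tuner searches for a workable chunk
    # size, this one just doubles the requested minimum.
    def tune_chunk_size(self, representative_fn, args, min_chunk_size):
        return min_chunk_size * 2

def maybe_tune(chunk_size, chunk_size_tuner, representative_fn, args):
    # Mirrors the guard added to TemplatePairStack.forward above.
    if chunk_size is not None and chunk_size_tuner is not None:
        chunk_size = chunk_size_tuner.tune_chunk_size(
            representative_fn=representative_fn,
            args=args,
            min_chunk_size=chunk_size,
        )
    return chunk_size

assert maybe_tune(None, FakeTuner(), lambda t: t, ("t",)) is None  # no chunk size given
assert maybe_tune(4, None, lambda t: t, ("t",)) == 4               # tuning not enabled
assert maybe_tune(4, FakeTuner(), lambda t: t, ("t",)) == 8        # tuned upward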