"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "a2773f3e1d438e2f98d8e19c744520f9ea39f99f"
Commit d6b36a80 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Tweak chunk tuning a little

parent 29f2ffe0
...@@ -185,12 +185,16 @@ class EvoformerBlockCore(nn.Module): ...@@ -185,12 +185,16 @@ class EvoformerBlockCore(nn.Module):
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False, use_lma: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans # DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of # should be disabled to better approximate the exact activations of
# the original. # the original.
msa_trans_mask = msa_mask if _mask_trans else None msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None pair_trans_mask = pair_mask if _mask_trans else None
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
# Need to dodge activation checkpoints # Need to dodge activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled()) inplace_safe = not (self.training or torch.is_grad_enabled())
...@@ -240,7 +244,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -240,7 +244,7 @@ class EvoformerBlockCore(nn.Module):
self.tri_att_start( self.tri_att_start(
z, z,
mask=pair_mask, mask=pair_mask,
chunk_size=chunk_size, chunk_size=_attn_chunk_size,
use_lma=use_lma use_lma=use_lma
) )
), ),
...@@ -251,7 +255,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -251,7 +255,7 @@ class EvoformerBlockCore(nn.Module):
self.tri_att_end( self.tri_att_end(
z, z,
mask=pair_mask, mask=pair_mask,
chunk_size=chunk_size, chunk_size=_attn_chunk_size,
use_lma=use_lma, use_lma=use_lma,
) )
), ),
...@@ -324,21 +328,33 @@ class EvoformerBlock(nn.Module): ...@@ -324,21 +328,33 @@ class EvoformerBlock(nn.Module):
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False, use_lma: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer( inplace_safe = not (self.training or torch.is_grad_enabled())
self.msa_att_row(
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)
m = add(m,
self.msa_att_col(
m, m,
z=z,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
) ),
) inplace=inplace_safe,
m = m + self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
) )
m, z = self.core( m, z = self.core(
m, m,
...@@ -348,6 +364,7 @@ class EvoformerBlock(nn.Module): ...@@ -348,6 +364,7 @@ class EvoformerBlock(nn.Module):
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
) )
return m, z return m, z
...@@ -421,7 +438,11 @@ class ExtraMSABlock(nn.Module): ...@@ -421,7 +438,11 @@ class ExtraMSABlock(nn.Module):
use_lma: bool = False, use_lma: bool = False,
_chunk_logits: Optional[int] = 1024, _chunk_logits: Optional[int] = 1024,
_mask_trans: bool = True, _mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
# If function calls could speak... # If function calls could speak...
m = add(m, m = add(m,
self.msa_dropout_layer( self.msa_dropout_layer(
...@@ -429,7 +450,7 @@ class ExtraMSABlock(nn.Module): ...@@ -429,7 +450,7 @@ class ExtraMSABlock(nn.Module):
m.clone() if torch.is_grad_enabled() else m, m.clone() if torch.is_grad_enabled() else m,
z=z.clone() if torch.is_grad_enabled() else z, z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=_attn_chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_memory_efficient_kernel=not _chunk_logits and not use_lma, use_memory_efficient_kernel=not _chunk_logits and not use_lma,
_chunk_logits= _chunk_logits=
...@@ -459,6 +480,7 @@ class ExtraMSABlock(nn.Module): ...@@ -459,6 +480,7 @@ class ExtraMSABlock(nn.Module):
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size
) )
return m, z return m, z
...@@ -621,12 +643,19 @@ class EvoformerStack(nn.Module): ...@@ -621,12 +643,19 @@ class EvoformerStack(nn.Module):
blocks = [partial(block_with_cache_clear, b) for b in blocks] blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None): if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size( tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0], representative_fn=blocks[0],
args=(m,z), args=(m,z),
min_chunk_size=chunk_size, min_chunk_size=chunk_size,
) )
blocks = [partial(b, chunk_size=chunk_size) for b in blocks] blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
) for b in blocks
]
blocks_per_ckpt = self.blocks_per_ckpt blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()): if(not torch.is_grad_enabled()):
...@@ -744,12 +773,19 @@ class ExtraMSAStack(nn.Module): ...@@ -744,12 +773,19 @@ class ExtraMSAStack(nn.Module):
blocks = [partial(clear_cache, b) for b in blocks] blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None): if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size( tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0], representative_fn=blocks[0],
args=(m,z), args=(m,z),
min_chunk_size=chunk_size, min_chunk_size=chunk_size,
) )
blocks = [partial(b, chunk_size=chunk_size) for b in blocks] blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
) for b in blocks
]
for b in blocks: for b in blocks:
if(self.ckpt and torch.is_grad_enabled()): if(self.ckpt and torch.is_grad_enabled()):
......
...@@ -41,6 +41,7 @@ from openfold.utils.feats import ( ...@@ -41,6 +41,7 @@ from openfold.utils.feats import (
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add, add,
chunk_layer, chunk_layer,
ChunkSizeTuner,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
tensor_tree_map, tensor_tree_map,
...@@ -293,6 +294,7 @@ class TemplatePairStack(nn.Module): ...@@ -293,6 +294,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n, pair_transition_n,
dropout_rate, dropout_rate,
blocks_per_ckpt, blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9, inf=1e9,
**kwargs, **kwargs,
): ):
...@@ -333,6 +335,11 @@ class TemplatePairStack(nn.Module): ...@@ -333,6 +335,11 @@ class TemplatePairStack(nn.Module):
self.layer_norm = LayerNorm(c_t) self.layer_norm = LayerNorm(c_t)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward( def forward(
self, self,
t: torch.tensor, t: torch.tensor,
...@@ -355,18 +362,28 @@ class TemplatePairStack(nn.Module): ...@@ -355,18 +362,28 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4] expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx) mask = mask.expand(*expand_idx)
blocks = [
partial(
b,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
_mask_trans=_mask_trans,
_inplace=not (self.training or torch.is_grad_enabled()),
)
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t,),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
t, = checkpoint_blocks( t, = checkpoint_blocks(
blocks=[ blocks=blocks,
partial(
b,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
_mask_trans=_mask_trans,
_inplace=not (self.training or torch.is_grad_enabled()),
)
for b in self.blocks
],
args=(t,), args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment