"vscode:/vscode.git/clone" did not exist on "b3e9af30462aa449576edf4018cd5ad6b4d7b5d8"
Commit 4c1cf6e2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add chunk size tuning

parent 8036a213
...@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import ( ...@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import add, chunk_layer from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
class MSATransition(nn.Module): class MSATransition(nn.Module):
...@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module): ...@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module):
inf: float, inf: float,
eps: float, eps: float,
clear_cache_between_blocks: bool = False, clear_cache_between_blocks: bool = False,
tune_chunk_size: bool = True,
**kwargs, **kwargs,
): ):
""" """
...@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module): ...@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module):
clear_cache_between_blocks: clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
""" """
super(EvoformerStack, self).__init__() super(EvoformerStack, self).__init__()
...@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module): ...@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module):
self.linear = Linear(c_m, c_s) self.linear = Linear(c_m, c_s)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self, def forward(self,
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
...@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module): ...@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module):
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
pair_mask: pair_mask:
[*, N_res, N_res] pair mask [*, N_res, N_res] pair mask
chunk_size: Inference-time subbatch size chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference use_lma: Whether to use low-memory attention during inference
Returns: Returns:
m: m:
...@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module): ...@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module):
] ]
if(self.clear_cache_between_blocks): if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args): def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return block(*args) return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks] blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
blocks_per_ckpt = self.blocks_per_ckpt blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()): if(not torch.is_grad_enabled()):
blocks_per_ckpt = None blocks_per_ckpt = None
...@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module): ...@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module):
ckpt: bool, ckpt: bool,
clear_cache_between_blocks: bool = False, clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False, chunk_msa_attn: bool = False,
tune_chunk_size: bool = True,
**kwargs, **kwargs,
): ):
super(ExtraMSAStack, self).__init__() super(ExtraMSAStack, self).__init__()
...@@ -674,6 +693,11 @@ class ExtraMSAStack(nn.Module): ...@@ -674,6 +693,11 @@ class ExtraMSAStack(nn.Module):
) )
self.blocks.append(block) self.blocks.append(block)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self, def forward(self,
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
...@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module): ...@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module):
) for b in self.blocks ) for b in self.blocks
] ]
def clear_cache(b, *args): def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return b(*args) return b(*args, **kwargs)
if(self.clear_cache_between_blocks): if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks] blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
for b in blocks: for b in blocks:
if(self.ckpt and torch.is_grad_enabled()): if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, *(m, z)) m, z = checkpoint_fn(b, *(m, z))
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import logging
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
...@@ -417,3 +419,74 @@ def chunk_layer( ...@@ -417,3 +419,74 @@ def chunk_layer(
out = tensor_tree_map(reshape, out) out = tensor_tree_map(reshape, out)
return out return out
class ChunkSizeTuner:
    """Finds and caches the largest usable chunk size for a module.

    Runtimes of chunked modules tend to plateau once the chunk size is
    large enough, so this binary-searches the largest power-of-two chunk
    size (up to ``max_chunk_size``) at which a representative call runs
    without raising a ``RuntimeError`` (e.g. CUDA OOM), and caches the
    result for reuse while the argument shapes/values stay the same.
    """
    def __init__(self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs I've run the model on.
        max_chunk_size: int = 256,
    ):
        self.max_chunk_size = max_chunk_size
        # Chunk size selected by the last tuning run (None until tuned)
        self.cached_chunk_size = None
        # Fingerprint of the args used for the last tuning run: tensor
        # args are reduced to their shapes, all other args kept as-is
        self.cached_arg_data = None

    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
        """Binary-search the largest viable power-of-two chunk size.

        A candidate is viable when ``fn(*args, chunk_size=c)`` completes
        without raising ``RuntimeError``. ``min_chunk_size`` is always
        treated as viable and acts as the floor of the search.
        """
        logging.info("Tuning chunk size...")

        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        # Powers of two up to and including max_chunk_size, floored at
        # min_chunk_size (candidates are kept sorted ascending)
        candidates = [
            2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)
        ]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates

        def test_chunk_size(chunk_size):
            # A RuntimeError (typically CUDA OOM) marks the size unviable
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        # Binary search for the largest viable candidate
        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        return candidates[min_viable_chunk_size_index]

    def _args_consistent(self, arg_data) -> bool:
        """Whether arg_data matches the fingerprint of the last tuning run."""
        if self.cached_arg_data is None:
            return False
        assert len(self.cached_arg_data) == len(arg_data)
        for cached, current in zip(self.cached_arg_data, arg_data):
            assert type(cached) == type(current)
            if cached != current:
                return False
        return True

    def tune_chunk_size(self,
        representative_fn: Callable,
        args: Tuple[Any],
        min_chunk_size: int,
    ) -> int:
        """Return a favorable chunk size for representative_fn.

        Args:
            representative_fn:
                A callable accepting ``*args`` plus a ``chunk_size``
                keyword, representative of the workload being tuned
            args:
                Positional arguments passed to representative_fn; tensor
                arguments contribute only their shapes to the cache key
            min_chunk_size:
                Lower bound on the returned chunk size
        Returns:
            The cached chunk size if the args are unchanged since the
            last call; otherwise the result of a fresh tuning run
        """
        arg_data = [
            arg.shape if isinstance(arg, torch.Tensor) else arg
            for arg in args
        ]

        # Re-tune only when the args' shapes/values have changed since
        # the last run; otherwise reuse the precomputed value
        if not self._args_consistent(arg_data):
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        return self.cached_chunk_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment