"examples/ScanNet/data.py" did not exist on "a3a079efc2ef1dcd83a4ba9dfa395e52506814ba"
Commit 4c1cf6e2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add chunk size tuning

parent 8036a213
......@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import add, chunk_layer
from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
class MSATransition(nn.Module):
......@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module):
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
tune_chunk_size: bool = True,
**kwargs,
):
"""
......@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module):
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
"""
super(EvoformerStack, self).__init__()
......@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module):
self.linear = Linear(c_m, c_s)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
......@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module):
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size: Inference-time subbatch size
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
Returns:
m:
......@@ -590,7 +600,7 @@ class EvoformerStack(nn.Module):
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
"""
blocks = [
partial(
b,
......@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module):
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache()
return block(*args)
return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
blocks_per_ckpt = None
......@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module):
ckpt: bool,
clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False,
tune_chunk_size: bool = True,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
......@@ -673,6 +692,11 @@ class ExtraMSAStack(nn.Module):
ckpt=ckpt if chunk_msa_attn else False,
)
self.blocks.append(block)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self,
m: torch.Tensor,
......@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module):
) for b in self.blocks
]
def clear_cache(b, *args):
def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache()
return b(*args)
return b(*args, **kwargs)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, *(m, z))
......
......@@ -14,6 +14,8 @@
# limitations under the License.
from functools import partial
import logging
import math
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
......@@ -417,3 +419,74 @@ def chunk_layer(
out = tensor_tree_map(reshape, out)
return out
class ChunkSizeTuner:
    """Empirically picks the largest viable chunk size for a chunked module.

    The tuner trial-runs a representative function at power-of-2 chunk sizes
    and caches the largest one that completes without raising (a
    ``RuntimeError`` during the trial is typically a CUDA OOM). The cached
    result is reused until the shapes/values of the arguments change.
    """

    def __init__(self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs the model has been run on.
        max_chunk_size: int = 256,
    ):
        self.max_chunk_size = max_chunk_size
        # Most recently tuned chunk size (None until the first tuning run)
        self.cached_chunk_size = None
        # Fingerprint (tensor shapes / raw values) of the args used for the
        # cached tuning run
        self.cached_arg_data = None

    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
        """Search for the largest chunk size <= self.max_chunk_size at which
        fn(*args, chunk_size=...) runs successfully.

        Candidates are min_chunk_size plus all larger powers of 2 up to
        self.max_chunk_size; a binary search over that (sorted) list finds
        the largest viable one.
        """
        logging.info("Tuning chunk size...")

        # Nothing to search if the floor already meets or exceeds the cap
        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        candidates = [
            2 ** l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)
        ]
        candidates = [min_chunk_size] + [
            c for c in candidates if c > min_chunk_size
        ]

        def test_chunk_size(chunk_size):
            # A candidate is viable iff the trial run completes without
            # raising (RuntimeError here usually means CUDA OOM)
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        # Binary search for the largest viable candidate
        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        return candidates[min_viable_chunk_size_index]

    def tune_chunk_size(self,
        representative_fn: Callable,
        args: Tuple[Any, ...],
        min_chunk_size: int,
    ) -> int:
        """Return a favorable chunk size for representative_fn.

        Args:
            representative_fn:
                A function whose memory behavior is representative of the
                chunked computation being tuned. Must accept *args plus a
                chunk_size keyword argument.
            args:
                Positional arguments for representative_fn. Tensors are
                fingerprinted by shape; everything else by value.
            min_chunk_size:
                Lower bound on the returned chunk size.
        Returns:
            The tuned chunk size (cached across calls with unchanged args).
        """
        # Fingerprint the args: tensors by shape, other values as-is
        arg_data = [
            a.shape if isinstance(a, torch.Tensor) else a for a in args
        ]

        if (
            self.cached_arg_data is None
            or len(self.cached_arg_data) != len(arg_data)
        ):
            # First call (or arg count changed): nothing reusable yet
            consistent = False
        else:
            # Re-tune only if any arg changed shape/value. NOTE: previously
            # only the *last* arg was compared (the loop variable shadowed
            # the arg_data list), which could silently reuse a stale size.
            consistent = all(
                type(cached) == type(current) and cached == current
                for cached, current in zip(self.cached_arg_data, arg_data)
            )

        if not consistent:
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        return self.cached_chunk_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment