Commit 748a8c50 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add DeepSpeed check to checkpointing function

parent fcf9e2f8
...@@ -19,7 +19,6 @@ from typing import Tuple, Optional ...@@ -19,7 +19,6 @@ from typing import Tuple, Optional
from functools import partial from functools import partial
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
from openfold.model.msa import ( from openfold.model.msa import (
MSARowAttentionWithPairBias, MSARowAttentionWithPairBias,
...@@ -36,6 +35,7 @@ from openfold.model.triangular_multiplicative_update import ( ...@@ -36,6 +35,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.tensor_utils import chunk_layer
......
...@@ -19,7 +19,6 @@ import torch ...@@ -19,7 +19,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, Attention from openfold.model.primitives import Linear, Attention
from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import ( from openfold.model.dropout import (
DropoutRowwise, DropoutRowwise,
DropoutColumnwise, DropoutColumnwise,
...@@ -33,6 +32,7 @@ from openfold.model.triangular_multiplicative_update import ( ...@@ -33,6 +32,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
......
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import deepspeed
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable
# Type aliases for the values threaded between blocks: each block consumes
# the previous block's outputs, so the payload is opaque at this level.
BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]
def checkpoint_blocks(
    blocks: List[Callable],
    args: Any,
    blocks_per_ckpt: int,
) -> Tuple[Any, ...]:
    """
    Chunk a list of blocks and run each chunk with activation
    checkpointing. We define a "block" as a callable whose only inputs are
    the outputs of the previous block.

    Uses DeepSpeed's activation checkpointing when DeepSpeed checkpointing
    has been configured, and falls back to torch.utils.checkpoint
    otherwise.

    Implements Subsection 1.11.8

    Args:
        blocks:
            List of blocks
        args:
            Tuple of arguments for the first block. A single non-tuple
            value is also accepted and wrapped automatically.
        blocks_per_ckpt:
            Size of each chunk. A higher value corresponds to higher memory
            consumption but fewer checkpoints. If None, no checkpointing is
            performed.
    Returns:
        The output of the final block, always wrapped in a tuple.
    Raises:
        ValueError: if blocks_per_ckpt is not None and lies outside
            [1, len(blocks)].
    """
    def wrap(a):
        # Normalize a single return value to a 1-tuple so "*a" always works.
        # type() rather than isinstance() is deliberate: only exact tuples
        # are treated as multi-value returns.
        return (a,) if type(a) is not tuple else a

    # Renamed from "exec" to avoid shadowing the builtin of that name.
    def run_blocks(b, a):
        # Feed each block the (wrapped) outputs of the previous one.
        for block in b:
            a = wrap(block(*a))
        return a

    def chunker(s, e):
        # Bind a slice of blocks into a single callable for checkpoint().
        def exec_sliced(*a):
            return run_blocks(blocks[s:e], a)

        return exec_sliced

    # Avoids mishaps when the blocks take just one argument
    args = wrap(args)

    if blocks_per_ckpt is None:
        return run_blocks(blocks, args)
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

    # Prefer DeepSpeed's checkpoint implementation when it is configured;
    # otherwise use the stock PyTorch one.
    if deepspeed.checkpointing.is_configured():
        checkpoint = deepspeed.checkpointing.checkpoint
    else:
        checkpoint = torch.utils.checkpoint.checkpoint

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        args = checkpoint(chunker(s, e), *args)
        # checkpoint() may unwrap a 1-tuple; re-wrap for the next chunk.
        args = wrap(args)

    return args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment