Commit df89bb28 authored by Gustaf Ahdritz

Add chunking experiment

parent 70d6bda5
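The changes below split the extra-MSA stack into standalone per-block modules and thread a _chunk_and_checkpoint option through the attention primitives: the flattened batch-like dimension is processed in fixed-size chunks, and each chunk runs under activation checkpointing so its intermediates are recomputed during the backward pass rather than stored. A minimal sketch of that general pattern in plain PyTorch (the names fn and x are illustrative, not part of this commit):

import torch
import torch.utils.checkpoint

def chunk_and_checkpoint(fn, x, chunk_size):
    # Apply fn to slices of the leading dimension of x, checkpointing each
    # slice so only its inputs (not its activations) are kept for backward.
    out = []
    for start in range(0, x.shape[0], chunk_size):
        out.append(torch.utils.checkpoint.checkpoint(fn, x[start:start + chunk_size]))
    return torch.cat(out, dim=0)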
......@@ -318,10 +318,10 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
......
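The new "ckpt" entry in the extra-MSA stack config is derived from the existing blocks_per_ckpt setting, so per-block checkpointing is enabled exactly when block-level checkpointing is configured. A small illustration of the derivation (the value 1 is only an example):

blocks_per_ckpt = 1                        # any non-None value
ckpt = blocks_per_ckpt is not None         # True; blocks_per_ckpt = None gives False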
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
......@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import chunk_layer
......@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
return m
class EvoformerBlock(nn.Module):
class EvoformerBlockCore(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerBlock, self).__init__()
self._is_extra_msa_stack = _is_extra_msa_stack
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
super(EvoformerBlockCore, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
......@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
transition_n,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
......@@ -220,10 +192,6 @@ class EvoformerBlock(nn.Module):
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
......@@ -245,6 +213,174 @@ class EvoformerBlock(nn.Module):
return m, z
class EvoformerBlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
):
super().__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
return m, z
class ExtraMSABlock(nn.Module):
"""
Almost identical to the standard EvoformerBlock, except that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
requires more fine-grained control over checkpointing. Separated from
its twin to preserve the TorchScript-ability of the latter.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
):
super().__init__()
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
checkpoint_chunk_size: Optional[int] = 512,
) -> Tuple[torch.Tensor, torch.Tensor]:
checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
m = m + self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=chunk_size,
_chunk_and_checkpoint=checkpoint_chunk_size,
)
)
def fn(m, z):
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
if(self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
else:
m, z = fn(m, z)
return m, z
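For reference, a hedged sketch of driving one of these blocks directly, assuming it is importable from openfold.model.evoformer; the channel sizes, sequence counts, and chunk sizes below are illustrative, not values taken from this commit:

import torch
from openfold.model.evoformer import ExtraMSABlock

block = ExtraMSABlock(
    c_m=64, c_z=128, c_hidden_msa_att=8, c_hidden_opm=32, c_hidden_mul=128,
    c_hidden_pair_att=32, no_heads_msa=8, no_heads_pair=4, transition_n=4,
    msa_dropout=0.15, pair_dropout=0.25, inf=1e9, eps=1e-10, ckpt=True,
)
m = torch.rand(1, 128, 64, 64)        # [batch, N_extra_seq, N_res, c_m]
z = torch.rand(1, 64, 64, 128)        # [batch, N_res, N_res, c_z]
msa_mask = torch.ones(1, 128, 64)
pair_mask = torch.ones(1, 64, 64)
# With ckpt=True, the column attention and the block core are wrapped in a
# checkpoint and recomputed during backward rather than cached.
m, z = block(m, z, msa_mask, pair_mask, chunk_size=4, checkpoint_chunk_size=64)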
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
......@@ -271,7 +407,6 @@ class EvoformerStack(nn.Module):
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False,
**kwargs,
):
"""
......@@ -313,7 +448,6 @@ class EvoformerStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self._is_extra_msa_stack = _is_extra_msa_stack
self.blocks = nn.ModuleList()
......@@ -332,15 +466,12 @@ class EvoformerStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
)
self.blocks.append(block)
if not self._is_extra_msa_stack:
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -390,8 +521,6 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = None
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
......@@ -405,8 +534,7 @@ class ExtraMSAStack(nn.Module):
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -419,38 +547,38 @@ class ExtraMSAStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
inf: float,
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
c_s = None
self.stack = EvoformerStack(
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
c_s=c_s,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
no_blocks=no_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
blocks_per_ckpt=blocks_per_ckpt,
inf=inf,
eps=eps,
clear_cache_between_blocks=clear_cache_between_blocks,
_is_extra_msa_stack=True,
ckpt=ckpt,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
......@@ -471,12 +599,10 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
_, z, _ = self.stack(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
......@@ -336,7 +336,9 @@ class AlphaFold(nn.Module):
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
self.extra_msa_stack.stack.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
......@@ -345,9 +347,9 @@ class AlphaFold(nn.Module):
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
self.extra_msa_stack.stack.blocks_per_ckpt = (
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
for b in self.extra_msa_stack.blocks:
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch):
"""
......
......@@ -93,6 +93,7 @@ class MSAAttention(nn.Module):
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_and_checkpoint: Optional[int] = None
) -> torch.Tensor:
"""
Args:
......@@ -125,9 +126,9 @@ class MSAAttention(nn.Module):
# This step simply returns a larger view of the bias, and does not
# consume additional memory.
# [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand(
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
#bias = bias.expand(
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
#)
biases = [bias]
......@@ -150,7 +151,13 @@ class MSAAttention(nn.Module):
if chunk_size is not None:
m = self._chunk(m, biases, chunk_size)
else:
m = self.mha(
q_x=m,
k_x=m,
v_x=m,
biases=biases,
_chunk_and_checkpoint=_chunk_and_checkpoint
)
return m
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
......@@ -21,6 +22,7 @@ import torch
import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
......@@ -164,6 +166,67 @@ class Linear(nn.Linear):
raise ValueError("Invalid init string.")
def _attention(query, key, value, biases):
a = torch.matmul(query, key)
for b in biases:
a += b
a = torch.nn.functional.softmax(a, dim=-1)
# [*, H, Q, C_hidden]
o = torch.matmul(a, value)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
return o
@torch.jit.ignore
def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
if(len(biases) > 2):
raise ValueError(
"_chunk_and_checkpoint only permits two bias terms"
)
biases = biases + [None, None]
bias_1, bias_2 = biases[:2]
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b1, b2]
return _attention(q, k, v, bs)
batch_dims = query.shape[:-3]
no_batch_dims = len(query.shape[:-3])
# q, k, and v are assumed to have no singleton dimensions
flat_q = query.reshape(-1, *query.shape[-3:])
flat_k = key.reshape(-1, *key.shape[-3:])
flat_v = value.reshape(-1, *value.shape[-3:])
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = flat_q.shape[0]
for start in range(0, count, chunk_size):
end = start + chunk_size
q_chunk = flat_q[start: end, ...]
k_chunk = flat_k[start: end, ...]
v_chunk = flat_v[start: end, ...]
bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
o_chunk = checkpoint_fn(_checkpointable_attention,
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
)
o_chunks.append(o_chunk)
o_flat = torch.cat(o_chunks, dim=0)
return o_flat.reshape(batch_dims + o_flat.shape[1:])
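A hedged usage sketch of the helper above. The shapes are hypothetical, q/k/v are assumed to be already projected and permuted as in Attention.forward, and the sketch assumes the _chunk_slice helper (not shown in this hunk) slices each bias along the flattened batch dimension:

q = torch.rand(16, 4, 32, 8, requires_grad=True)    # [batch, H, Q, C_hidden]
k = torch.rand(16, 4, 8, 32, requires_grad=True)     # [batch, H, C_hidden, K]
v = torch.rand(16, 4, 32, 8, requires_grad=True)     # [batch, H, K, C_hidden]
b1 = torch.zeros(16, 4, 32, 32)                       # additive attention biases
b2 = torch.zeros(16, 4, 32, 32)
# Four batch rows at a time; each chunk's attention is recomputed in backward.
o = _attention_chunk_and_checkpoint(q, k, v, [b1, b2], chunk_size=4)
print(o.shape)                                        # [batch, Q, H, C_hidden]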
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
......@@ -225,7 +288,6 @@ class Attention(nn.Module):
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(
self,
......@@ -233,6 +295,10 @@ class Attention(nn.Module):
k_x: torch.Tensor,
v_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
_chunk_and_checkpoint: Optional[int] = None
) -> torch.Tensor:
"""
Args:
......@@ -245,6 +311,18 @@ class Attention(nn.Module):
Returns
[*, Q, C_q] attention update
"""
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
)
if(use_lma and _chunk_and_checkpoint is not None):
raise ValueError(
"use_lma and _chunk_and_checkpoint are mutually exclusive"
)
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
......@@ -255,34 +333,33 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(self.c_hidden)
if(use_lma):
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
# [*, H, Q, C_hidden]
q = permute_final_dims(q, (1, 0, 2))
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, Q, K]
a = torch.matmul(q, k)
del q, k
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
if biases is not None:
for b in biases:
a += b
a = self.softmax(a)
# [*, H, V, C_hidden]
v = permute_final_dims(v, (1, 0, 2))
# [*, H, Q, C_hidden]
o = torch.matmul(a, v)
if(_chunk_and_checkpoint):
# REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
# CHECKPOINTED HERE
o = _attention_chunk_and_checkpoint(
q, k, v, biases, _chunk_and_checkpoint
)
else:
o = _attention(q, k, v, biases)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
......@@ -374,14 +451,13 @@ class GlobalAttention(nn.Module):
return m
@torch.jit.script
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-3], k.shape[-3]
......@@ -389,7 +465,7 @@ def _lma(
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
......@@ -400,11 +476,11 @@ def _lma(
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk
"...qhd,...khd->...hqk", query, key
)
for b in small_bias_chunks:
......@@ -412,11 +488,11 @@ def _lma(
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)
......@@ -437,111 +513,3 @@ def _lma(
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
return o
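The per-chunk statistics collected above (maxes, weights, values) are merged outside this hunk. For reference, the standard Rabe and Staats combination they support looks like the following sketch; it reproduces the math, not necessarily the exact elided code:

# maxes:   per-kv-chunk maxima,               each [*, Q, H]
# weights: per-chunk sums of exp(a - max),    each [*, Q, H]
# values:  per-chunk exp-weighted value sums, each [*, Q, H, C_hidden]
chunk_max = torch.stack(maxes, dim=-3)                  # [*, C, Q, H]
chunk_weights = torch.stack(weights, dim=-3)            # [*, C, Q, H]
chunk_values = torch.stack(values, dim=-4)              # [*, C, Q, H, C_hidden]
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
correction = torch.exp(chunk_max - global_max)          # rescale each chunk's stats
chunk_values = chunk_values * correction.unsqueeze(-1)
chunk_weights = chunk_weights * correction
all_values = torch.sum(chunk_values, dim=-4)            # [*, Q, H, C_hidden]
all_weights = torch.sum(chunk_weights, dim=-3).unsqueeze(-1)
q_chunk_out = all_values / all_weights                  # normalized attention output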
class LowMemoryAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors. Implements Rabe and Staats'
low-memory self-attention algorithm.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
chunk_size:
Trades memory for better parallelization. A low value
corresponds to lower memory usage.
"""
super().__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_chunk_size: int,
kv_chunk_size: int,
biases: Optional[List[torch.Tensor]] = None,
):
if(biases is None):
biases = []
else:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(q.shape[-1])
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod, partial
import math
from typing import Optional, List
......@@ -70,7 +70,7 @@ class TriangleAttention(nn.Module):
"biases": biases,
}
return chunk_layer(
partial(self.mha),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
......
......@@ -15,17 +15,27 @@
import deepspeed
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional
BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
return checkpoint
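get_checkpoint_fn returns DeepSpeed's activation-checkpointing function when DeepSpeed checkpointing is configured and falls back to torch.utils.checkpoint.checkpoint otherwise, so callers can checkpoint arbitrary callables without caring which backend is active. A small usage sketch (the layer and tensor are illustrative):

checkpoint = get_checkpoint_fn()
layer = torch.nn.Linear(32, 32)
x = torch.rand(8, 32, requires_grad=True)
# The forward pass runs normally; the layer's activations are recomputed
# during backward instead of being stored.
y = checkpoint(layer, x)
y.sum().backward()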
@torch.jit.ignore
def checkpoint_blocks(
blocks: List[Callable],
args: BLOCK_ARGS,
blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
"""
Chunk a list of blocks and run each chunk with activation
......@@ -68,10 +78,7 @@ def checkpoint_blocks(
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
checkpoint = get_checkpoint_fn()
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
......
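checkpoint_blocks now also accepts blocks_per_ckpt=None, which runs the blocks without any checkpointing (as the Evoformer stack does at inference time). A hedged usage sketch with hypothetical two-argument blocks, mirroring how the Evoformer blocks consume and return (m, z):

import torch
from openfold.utils.checkpointing import checkpoint_blocks

def make_block():
    lin = torch.nn.Linear(16, 16)
    def block(m, z):
        # Each block takes the running arguments and returns updated ones.
        return lin(m), z
    return block

blocks = [make_block() for _ in range(4)]
m = torch.rand(2, 16, requires_grad=True)
z = torch.rand(2, 16)
# Two blocks per checkpointed segment; blocks_per_ckpt=None disables checkpointing.
m, z = checkpoint_blocks(blocks, args=(m, z), blocks_per_ckpt=2)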