Commit df89bb28 authored by Gustaf Ahdritz

Add chunking experiment

parent 70d6bda5
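In summary, this change: derives a boolean `ckpt` flag for the extra-MSA stack config from `blocks_per_ckpt is not None`; splits the Evoformer block into a shared `EvoformerBlockCore` used by a new, TorchScript-friendly `EvoformerBlock` and a new `ExtraMSABlock`; rebuilds `ExtraMSAStack` as a plain list of `ExtraMSABlock`s instead of wrapping an `EvoformerStack`; and threads a `_chunk_and_checkpoint` size through MSA row attention into `Attention.forward`, where `_attention_chunk_and_checkpoint` slices the flattened batch dimension and recomputes each slice during backward via `get_checkpoint_fn`. A minimal sketch of that chunk-and-checkpoint pattern in plain PyTorch (the function names below are invented for illustration and are not part of OpenFold):

# Illustrative only: a toy version of the chunk-and-checkpoint pattern used in
# this commit. toy_attention and chunked_checkpointed_attention are made-up names.
import torch
import torch.utils.checkpoint

def toy_attention(q, k, v):
    # [*, Q, K] attention weights over the key dimension
    a = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return a @ v

def chunked_checkpointed_attention(q, k, v, chunk_size):
    # Slice the leading (batch-like) dimension and checkpoint each slice so
    # its attention intermediates are recomputed in backward, not stored.
    out = []
    for start in range(0, q.shape[0], chunk_size):
        end = start + chunk_size
        out.append(
            torch.utils.checkpoint.checkpoint(
                toy_attention, q[start:end], k[start:end], v[start:end]
            )
        )
    return torch.cat(out, dim=0)

q = k = v = torch.randn(8, 16, 32, requires_grad=True)
o = chunked_checkpointed_attention(q, k, v, chunk_size=2)
o.sum().backward()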
@@ -318,10 +318,10 @@ config = mlc.ConfigDict(
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
-               "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": True,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
+               "ckpt": blocks_per_ckpt is not None,
            },
            "enabled": True,
        },
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import torch
 import torch.nn as nn
 from typing import Tuple, Optional
@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationOutgoing,
     TriangleMultiplicationIncoming,
 )
-from openfold.utils.checkpointing import checkpoint_blocks
+from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
 from openfold.utils.tensor_utils import chunk_layer
@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
         return m


-class EvoformerBlock(nn.Module):
+class EvoformerBlockCore(nn.Module):
     def __init__(
         self,
         c_m: int,
         c_z: int,
-        c_hidden_msa_att: int,
         c_hidden_opm: int,
         c_hidden_mul: int,
         c_hidden_pair_att: int,
         no_heads_msa: int,
         no_heads_pair: int,
         transition_n: int,
-        msa_dropout: float,
         pair_dropout: float,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
     ):
-        super(EvoformerBlock, self).__init__()
+        super(EvoformerBlockCore, self).__init__()
-
-        self._is_extra_msa_stack = _is_extra_msa_stack
-
-        self.msa_att_row = MSARowAttentionWithPairBias(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-        )
-
-        if _is_extra_msa_stack:
-            self.msa_att_col = MSAColumnGlobalAttention(
-                c_in=c_m,
-                c_hidden=c_hidden_msa_att,
-                no_heads=no_heads_msa,
-                inf=inf,
-                eps=eps,
-            )
-        else:
-            self.msa_att_col = MSAColumnAttention(
-                c_m,
-                c_hidden_msa_att,
-                no_heads_msa,
-                inf=inf,
-            )
-
         self.msa_transition = MSATransition(
             c_m=c_m,
@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
             transition_n,
         )

-        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
         self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
@@ -220,10 +192,6 @@
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
-        )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
         m = m + self.msa_transition(
             m, mask=msa_trans_mask, chunk_size=chunk_size
         )
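The hunks that follow reintroduce `EvoformerBlock` as a thin wrapper that runs MSA row and column attention and then delegates to the shared `EvoformerBlockCore`, and add `ExtraMSABlock`, which can run column attention plus the core under an activation checkpoint controlled by its `ckpt` flag. A toy analogue of that wrapper/core split (module names below are invented for illustration, not OpenFold classes):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class ToyCore(nn.Module):
    # stands in for EvoformerBlockCore: the expensive shared tail of a block
    def __init__(self):
        super().__init__()
        self.transition = nn.Linear(8, 8)

    def forward(self, m):
        return m + self.transition(m)

class ToyBlock(nn.Module):
    # stands in for ExtraMSABlock: attention runs normally; the core may be
    # checkpointed so its activations are recomputed during backward
    def __init__(self, ckpt: bool = False):
        super().__init__()
        self.att = nn.Linear(8, 8)
        self.core = ToyCore()
        self.ckpt = ckpt

    def forward(self, m):
        m = m + self.att(m)
        if self.ckpt:
            return torch.utils.checkpoint.checkpoint(self.core, m)
        return self.core(m)

m = torch.randn(2, 8, requires_grad=True)
ToyBlock(ckpt=True)(m).sum().backward()

In the real blocks the checkpointed region also covers MSA column attention, as the diff below shows.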
@@ -245,6 +213,174 @@
         return m, z


+class EvoformerBlock(nn.Module):
+    def __init__(self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+    ):
+        super().__init__()
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnAttention(
+            c_m,
+            c_hidden_msa_att,
+            no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+        )
+        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m, z = self.core(
+            m,
+            z,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        return m, z
+
+
+class ExtraMSABlock(nn.Module):
+    """
+    Almost identical to the standard EvoformerBlock, except that the
+    ExtraMSABlock uses GlobalAttention for MSA column attention and
+    requires more fine-grained control over checkpointing. Separated from
+    its twin to preserve the TorchScript-ability of the latter.
+    """
+    def __init__(self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+        ckpt: bool,
+    ):
+        super().__init__()
+
+        self.ckpt = ckpt
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnGlobalAttention(
+            c_in=c_m,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+            eps=eps,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        checkpoint_chunk_size: Optional[int] = 512,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
+
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(
+                m,
+                z=z,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                _chunk_and_checkpoint=checkpoint_chunk_size,
+            )
+        )
+
+        def fn(m, z):
+            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m, z = self.core(
+                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+            )
+            return m, z
+
+        if(self.ckpt):
+            checkpoint_fn = get_checkpoint_fn()
+            m, z = checkpoint_fn(fn, m, z)
+        else:
+            m, z = fn(m, z)
+
+        return m, z
+
+
 class EvoformerStack(nn.Module):
     """
     Main Evoformer trunk.
@@ -271,7 +407,6 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
-        _is_extra_msa_stack: bool = False,
         **kwargs,
     ):
         """
@@ -313,7 +448,6 @@ class EvoformerStack(nn.Module):
         self.blocks_per_ckpt = blocks_per_ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self._is_extra_msa_stack = _is_extra_msa_stack

         self.blocks = nn.ModuleList()
@@ -332,15 +466,12 @@ class EvoformerStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                _is_extra_msa_stack=_is_extra_msa_stack,
             )
             self.blocks.append(block)

-        if not self._is_extra_msa_stack:
-            self.linear = Linear(c_m, c_s)
+        self.linear = Linear(c_m, c_s)

-    def forward(
-        self,
+    def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
@@ -390,12 +521,10 @@ class EvoformerStack(nn.Module):
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )

-        s = None
-        if not self._is_extra_msa_stack:
-            seq_dim = -3
-            index = torch.tensor([0], device=m.device)
-            s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
-            s = s.squeeze(seq_dim)
+        seq_dim = -3
+        index = torch.tensor([0], device=m.device)
+        s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
+        s = s.squeeze(seq_dim)

         return m, z, s
@@ -405,8 +534,7 @@ class ExtraMSAStack(nn.Module):
     Implements Algorithm 18.
     """

-    def __init__(
-        self,
+    def __init__(self,
         c_m: int,
         c_z: int,
         c_hidden_msa_att: int,
@@ -419,38 +547,38 @@ class ExtraMSAStack(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        blocks_per_ckpt: int,
         inf: float,
         eps: float,
+        ckpt: bool,
         clear_cache_between_blocks: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()

-        c_s = None
-        self.stack = EvoformerStack(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden_msa_att=c_hidden_msa_att,
-            c_hidden_opm=c_hidden_opm,
-            c_hidden_mul=c_hidden_mul,
-            c_hidden_pair_att=c_hidden_pair_att,
-            c_s=c_s,
-            no_heads_msa=no_heads_msa,
-            no_heads_pair=no_heads_pair,
-            no_blocks=no_blocks,
-            transition_n=transition_n,
-            msa_dropout=msa_dropout,
-            pair_dropout=pair_dropout,
-            blocks_per_ckpt=blocks_per_ckpt,
-            inf=inf,
-            eps=eps,
-            clear_cache_between_blocks=clear_cache_between_blocks,
-            _is_extra_msa_stack=True,
-        )
+        self.clear_cache_between_blocks = clear_cache_between_blocks
+
+        self.blocks = nn.ModuleList()
+        for _ in range(no_blocks):
+            block = ExtraMSABlock(
+                c_m=c_m,
+                c_z=c_z,
+                c_hidden_msa_att=c_hidden_msa_att,
+                c_hidden_opm=c_hidden_opm,
+                c_hidden_mul=c_hidden_mul,
+                c_hidden_pair_att=c_hidden_pair_att,
+                no_heads_msa=no_heads_msa,
+                no_heads_pair=no_heads_pair,
+                transition_n=transition_n,
+                msa_dropout=msa_dropout,
+                pair_dropout=pair_dropout,
+                inf=inf,
+                eps=eps,
+                ckpt=ckpt,
+            )
+            self.blocks.append(block)

-    def forward(
-        self,
+    def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
@@ -470,13 +598,11 @@ class ExtraMSAStack(nn.Module):
             Optional [*, N_res, N_res] pair mask
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        _, z, _ = self.stack(
-            m,
-            z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
-        )
+        for b in self.blocks:
+            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+            if(self.clear_cache_between_blocks):
+                torch.cuda.empty_cache()
+
         return z
@@ -336,7 +336,9 @@ class AlphaFold(nn.Module):
     def _disable_activation_checkpointing(self):
         self.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None
-        self.extra_msa_stack.stack.blocks_per_ckpt = None
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = False

     def _enable_activation_checkpointing(self):
         self.template_pair_stack.blocks_per_ckpt = (
@@ -345,9 +347,9 @@ class AlphaFold(nn.Module):
         self.evoformer.blocks_per_ckpt = (
             self.config.evoformer_stack.blocks_per_ckpt
         )
-        self.extra_msa_stack.stack.blocks_per_ckpt = (
-            self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
-        )
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt

     def forward(self, batch):
         """
@@ -93,6 +93,7 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        _chunk_and_checkpoint: Optional[int] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -125,9 +126,9 @@ class MSAAttention(nn.Module):
         # This step simply returns a larger view of the bias, and does not
         # consume additional memory.
         # [*, N_seq, no_heads, N_res, N_res]
-        bias = bias.expand(
-            ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-        )
+        #bias = bias.expand(
+        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
+        #)

         biases = [bias]
@@ -150,7 +151,13 @@ class MSAAttention(nn.Module):
         if chunk_size is not None:
             m = self._chunk(m, biases, chunk_size)
         else:
-            m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)
+            m = self.mha(
+                q_x=m,
+                k_x=m,
+                v_x=m,
+                biases=biases,
+                _chunk_and_checkpoint=_chunk_and_checkpoint
+            )

         return m
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
@@ -21,6 +22,7 @@ import torch
 import torch.nn as nn
 from scipy.stats import truncnorm

+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -164,6 +166,67 @@ class Linear(nn.Linear):
         raise ValueError("Invalid init string.")


+def _attention(query, key, value, biases):
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = torch.nn.functional.softmax(a, dim=-1)
+
+    # [*, H, Q, C_hidden]
+    o = torch.matmul(a, value)
+
+    # [*, Q, H, C_hidden]
+    o = o.transpose(-2, -3)
+
+    return o
+
+
+@torch.jit.ignore
+def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
+    if(len(biases) > 2):
+        raise ValueError(
+            "_chunk_and_checkpoint only permits two bias terms"
+        )
+
+    biases = biases + [None, None]
+    bias_1, bias_2 = biases[:2]
+
+    def _checkpointable_attention(q, k, v, b1, b2):
+        bs = [b1, b2]
+        return _attention(q, k, v, bs)
+
+    batch_dims = query.shape[:-3]
+    no_batch_dims = len(query.shape[:-3])
+
+    # q, k, and v are assumed to have no singleton dimensions
+    flat_q = query.reshape(-1, *query.shape[-3:])
+    flat_k = key.reshape(-1, *key.shape[-3:])
+    flat_v = value.reshape(-1, *value.shape[-3:])
+
+    o_chunks = []
+    checkpoint_fn = get_checkpoint_fn()
+    count = flat_q.shape[0]
+    for start in range(0, count, chunk_size):
+        end = start + chunk_size
+        q_chunk = flat_q[start: end, ...]
+        k_chunk = flat_k[start: end, ...]
+        v_chunk = flat_v[start: end, ...]
+        bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
+        bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
+
+        o_chunk = checkpoint_fn(_checkpointable_attention,
+            q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+        )
+
+        o_chunks.append(o_chunk)
+
+    o_flat = torch.cat(o_chunks, dim=0)
+
+    return o_flat.reshape(batch_dims + o_flat.shape[1:])
+
+
 class Attention(nn.Module):
     """
     Standard multi-head attention using AlphaFold's default layer
@@ -225,7 +288,6 @@ class Attention(nn.Module):
             )

         self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)

     def forward(
         self,
@@ -233,6 +295,10 @@ class Attention(nn.Module):
         k_x: torch.Tensor,
         v_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
+        use_lma: bool = False,
+        q_chunk_size: Optional[int] = None,
+        kv_chunk_size: Optional[int] = None,
+        _chunk_and_checkpoint: Optional[int] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -245,6 +311,18 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
+        if(biases is None):
+            biases = []
+
+        if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
+            raise ValueError(
+                "If use_lma is specified, q_chunk_size and kv_chunk_size must "
+                "be provided"
+            )
+
+        if(use_lma and _chunk_and_checkpoint is not None):
+            raise ValueError(
+                "use_lma and _chunk_and_checkpoint are mutually exclusive"
+            )
+
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
@@ -255,34 +333,33 @@
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

-        # [*, H, Q, C_hidden]
-        q = permute_final_dims(q, (1, 0, 2))
-
-        # [*, H, C_hidden, K]
-        k = permute_final_dims(k, (1, 2, 0))
-
-        # [*, H, Q, K]
-        a = torch.matmul(q, k)
-
-        del q, k
-
-        norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
-
-        if biases is not None:
-            for b in biases:
-                a += b
-
-        a = self.softmax(a)
-
-        # [*, H, V, C_hidden]
-        v = permute_final_dims(v, (1, 0, 2))
-
-        # [*, H, Q, C_hidden]
-        o = torch.matmul(a, v)
-
-        # [*, Q, H, C_hidden]
-        o = o.transpose(-2, -3)
+        q = q / math.sqrt(self.c_hidden)
+
+        if(use_lma):
+            biases = [
+                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
+                for b in biases
+            ]
+            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+        else:
+            # [*, H, Q, C_hidden]
+            q = permute_final_dims(q, (1, 0, 2))
+
+            # [*, H, C_hidden, K]
+            k = permute_final_dims(k, (1, 2, 0))
+
+            # [*, H, V, C_hidden]
+            v = permute_final_dims(v, (1, 0, 2))
+
+            if(_chunk_and_checkpoint):
+                # REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
+                # CHECKPOINTED HERE
+                o = _attention_chunk_and_checkpoint(
+                    q, k, v, biases, _chunk_and_checkpoint
+                )
+            else:
+                o = _attention(q, k, v, biases)

         if(self.linear_g is not None):
             g = self.sigmoid(self.linear_g(q_x))
             # [*, Q, H, C_hidden]
@@ -374,14 +451,13 @@ class GlobalAttention(nn.Module):
         return m


-@torch.jit.script
 def _lma(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     biases: List[torch.Tensor],
     q_chunk_size: int,
-    kv_chunk_size: int
+    kv_chunk_size: int,
 ):
     no_q, no_kv = q.shape[-3], k.shape[-3]
@@ -389,34 +465,34 @@ def _lma(
     o = q.new_zeros(q.shape)
     for q_s in range(0, no_q, q_chunk_size):
         q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
-        big_bias_chunks = [
+        large_bias_chunks = [
             b[..., q_s: q_s + q_chunk_size, :] for b in biases
         ]

         maxes = []
         weights = []
         values = []
         for kv_s in range(0, no_kv, kv_chunk_size):
             k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
             v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
             small_bias_chunks = [
-                b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
+                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
             ]

             a = torch.einsum(
                 "...qhd,...khd->...hqk", q_chunk, k_chunk
             )

             for b in small_bias_chunks:
                 a += b

             a = a.transpose(-2, -3)

-            max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
+            max_a = torch.max(a, dim=-1, keepdim=True)[0]
             exp_a = torch.exp(a - max_a)
             exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

-            maxes.append(max_a.squeeze(-1))
+            maxes.append(max_a.detach().squeeze(-1))
             weights.append(torch.sum(exp_a, dim=-1))
             values.append(exp_v)
@@ -437,111 +513,3 @@ def _lma(
         o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

     return o
-class LowMemoryAttention(nn.Module):
-    """
-    Standard multi-head attention using AlphaFold's default layer
-    initialization. Allows multiple bias vectors. Implements Rabe and Staats'
-    low-memory self-attention algorithm.
-    """
-    def __init__(
-        self,
-        c_q: int,
-        c_k: int,
-        c_v: int,
-        c_hidden: int,
-        no_heads: int,
-        gating: bool = True,
-    ):
-        """
-        Args:
-            c_q:
-                Input dimension of query data
-            c_k:
-                Input dimension of key data
-            c_v:
-                Input dimension of value data
-            c_hidden:
-                Per-head hidden dimension
-            no_heads:
-                Number of attention heads
-            gating:
-                Whether the output should be gated using query data
-            chunk_size:
-                Trades memory for better parallelization. A low value
-                corresponds to lower memory usage.
-        """
-        super().__init__()
-
-        self.c_q = c_q
-        self.c_k = c_k
-        self.c_v = c_v
-        self.c_hidden = c_hidden
-        self.no_heads = no_heads
-        self.gating = gating
-
-        self.linear_q = Linear(
-            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_k = Linear(
-            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_v = Linear(
-            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_o = Linear(
-            self.c_hidden * self.no_heads, self.c_q, init="final"
-        )
-
-        if self.gating:
-            self.linear_g = Linear(
-                self.c_q, self.c_hidden * self.no_heads, init="gating"
-            )
-
-        self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self,
-        q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
-        q_chunk_size: int,
-        kv_chunk_size: int,
-        biases: Optional[List[torch.Tensor]] = None,
-    ):
-        if(biases is None):
-            biases = []
-        else:
-            biases = [
-                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
-                for b in biases
-            ]
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(q.shape[-1])
-
-        o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
-
-        if self.gating:
-            g = self.sigmoid(self.linear_g(q_x))
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
-
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
-
-        return o
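For reference, `_lma` (kept above) implements the Rabe and Staats chunked-softmax attention that the deleted `LowMemoryAttention` wrapper used to expose: for each key/value chunk it records the running max, the stabilized exponential weights, and the weighted values, then renormalizes across chunks. A self-contained sketch of that accumulation in plain PyTorch (the helper name below is invented for illustration):

import torch

def chunked_softmax_attention(q, k, v, kv_chunk_size):
    # q: [Q, D], k/v: [K, D]; accumulate attention over key/value chunks
    maxes, weights, values = [], [], []
    for s in range(0, k.shape[0], kv_chunk_size):
        a = q @ k[s: s + kv_chunk_size].transpose(-1, -2)   # [Q, chunk]
        m = a.max(dim=-1, keepdim=True)[0]                   # per-query chunk max
        e = torch.exp(a - m)                                 # stabilized weights
        maxes.append(m)
        weights.append(e.sum(dim=-1, keepdim=True))
        values.append(e @ v[s: s + kv_chunk_size])
    m_all = torch.cat(maxes, dim=-1)                         # [Q, n_chunks]
    global_max = m_all.max(dim=-1, keepdim=True)[0]
    scale = torch.exp(m_all - global_max)                    # rescale each chunk
    num = sum(sc.unsqueeze(-1) * val for sc, val in zip(scale.unbind(-1), values))
    den = sum(sc * w.squeeze(-1) for sc, w in zip(scale.unbind(-1), weights))
    return num / den.unsqueeze(-1)

q, k, v = torch.randn(5, 4), torch.randn(12, 4), torch.randn(12, 4)
ref = torch.softmax(q @ k.T, dim=-1) @ v
assert torch.allclose(chunked_softmax_attention(q, k, v, 5), ref, atol=1e-5)

Only one key/value chunk of scores is held at a time, which is why the memory cost grows with the chunk sizes rather than with the full Q-by-K score matrix.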
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import partialmethod
+from functools import partialmethod, partial
 import math
 from typing import Optional, List
@@ -70,7 +70,7 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            self.mha,
+            partial(self.mha),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),
@@ -15,17 +15,27 @@
 import deepspeed
 import torch
 import torch.utils.checkpoint
-from typing import Any, Tuple, List, Callable
+from typing import Any, Tuple, List, Callable, Optional


 BLOCK_ARG = Any
 BLOCK_ARGS = List[BLOCK_ARG]


+def get_checkpoint_fn():
+    if(deepspeed.checkpointing.is_configured()):
+        checkpoint = deepspeed.checkpointing.checkpoint
+    else:
+        checkpoint = torch.utils.checkpoint.checkpoint
+
+    return checkpoint
+
+
 @torch.jit.ignore
 def checkpoint_blocks(
     blocks: List[Callable],
     args: BLOCK_ARGS,
-    blocks_per_ckpt: int,
+    blocks_per_ckpt: Optional[int],
 ) -> BLOCK_ARGS:
     """
     Chunk a list of blocks and run each chunk with activation
@@ -68,10 +78,7 @@ def checkpoint_blocks(
     elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
         raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

-    if(deepspeed.checkpointing.is_configured()):
-        checkpoint = deepspeed.checkpointing.checkpoint
-    else:
-        checkpoint = torch.utils.checkpoint.checkpoint
+    checkpoint = get_checkpoint_fn()

     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
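As a reminder of what `checkpoint_blocks`-style grouping does with the `blocks_per_ckpt` setting that this hunk leaves in place: blocks are run in groups of `blocks_per_ckpt`, and each group is rerun during backward instead of storing its intermediate activations. A toy sketch in plain PyTorch (the helper below is illustrative, not the OpenFold function itself):

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

def run_group(group, x):
    # run a contiguous group of blocks sequentially
    for block in group:
        x = block(x)
    return x

blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))
blocks_per_ckpt = 2  # illustrative value

x = torch.randn(4, 8, requires_grad=True)
out = x
for s in range(0, len(blocks), blocks_per_ckpt):
    group = list(blocks[s: s + blocks_per_ckpt])
    # the group's activations are recomputed, not stored, during backward
    out = torch.utils.checkpoint.checkpoint(partial(run_group, group), out)
out.sum().backward()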