Commit fe9ad07e authored by Gustaf Ahdritz

Add offloading to evoformer

parent b40fab25
@@ -16,7 +16,7 @@
 import math
 import torch
 import torch.nn as nn
-from typing import Tuple, Optional
+from typing import Tuple, Sequence, Optional
 from functools import partial
 from openfold.model.primitives import Linear, LayerNorm
@@ -29,6 +29,7 @@ from openfold.model.msa import (
 from openfold.model.outer_product_mean import OuterProductMean
 from openfold.model.pair_transition import PairTransition
 from openfold.model.triangular_attention import (
+    TriangleAttention,
     TriangleAttentionStartingNode,
     TriangleAttentionEndingNode,
 )
@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationIncoming,
 )
 from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
-from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
+from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
+from openfold.utils.tensor_utils import add


 class MSATransition(nn.Module):
@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
         self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

     def _transition(self, m, mask):
+        m = self.layer_norm(m)
         m = self.linear_1(m)
         m = self.relu(m)
         m = self.linear_2(m) * mask
@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
         mask = mask.unsqueeze(-1)

-        m = self.layer_norm(m)
-
         if chunk_size is not None:
             m = self._chunk(m, mask, chunk_size)
         else:
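Moving the layer norm from forward() into _transition() puts it inside the chunked computation, so chunked inference only ever normalizes a chunk-sized slice of the MSA activation at a time instead of the full tensor. A minimal sketch of the effect in plain PyTorch (not OpenFold's chunk_layer machinery; shapes and names are illustrative):

import torch
import torch.nn as nn

c_m, n = 32, 4
layer_norm = nn.LayerNorm(c_m)
linear_1 = nn.Linear(c_m, n * c_m)
linear_2 = nn.Linear(n * c_m, c_m)

def transition_chunk(m_chunk, mask_chunk):
    # the norm now runs per chunk, so its intermediate buffers are chunk-sized too
    m_chunk = layer_norm(m_chunk)
    return linear_2(torch.relu(linear_1(m_chunk))) * mask_chunk

m = torch.randn(8, 64, c_m)      # [N_seq, N_res, C_m]
mask = torch.ones(8, 64, 1)
out = torch.cat(
    [transition_chunk(mc, kc) for mc, kc in zip(m.split(2), mask.split(2))],
    dim=0,
)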
@@ -155,13 +156,13 @@ class EvoformerBlockCore(nn.Module):
             c_hidden_mul,
         )

-        self.tri_att_start = TriangleAttentionStartingNode(
+        self.tri_att_start = TriangleAttention(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
             inf=inf,
         )
-        self.tri_att_end = TriangleAttentionEndingNode(
+        self.tri_att_end = TriangleAttention(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
@@ -174,18 +175,16 @@ class EvoformerBlockCore(nn.Module):
         )

         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
-        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

-    def forward(
-        self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+    def forward(self,
+        input_tensors: Sequence[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
+        _offload_inference: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
@@ -196,6 +195,8 @@ class EvoformerBlockCore(nn.Module):
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

+        m, z = input_tensors
+
         # Need to dodge activation checkpoints
         inplace_safe = not (self.training or torch.is_grad_enabled())
@@ -205,13 +206,26 @@ class EvoformerBlockCore(nn.Module):
                 m, mask=msa_trans_mask, chunk_size=chunk_size,
             ),
             inplace=inplace_safe,
         )
-        z = add(z,
-            self.outer_product_mean(
-                m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
-            ),
-            inplace=inplace_safe,
-        )
+
+        if(_offload_inference and inplace_safe):
+            del m, z
+            input_tensors[1] = input_tensors[1].cpu()
+            torch.cuda.empty_cache()
+            m, z = input_tensors
+
+        opm = self.outer_product_mean(
+            m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
+        )
+
+        if(_offload_inference and inplace_safe):
+            del m, z
+            input_tensors[0] = input_tensors[0].cpu()
+            input_tensors[1] = input_tensors[1].to(opm.device)
+            m, z = input_tensors
+
+        z = add(z, opm, inplace=inplace_safe)
+        del opm
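This is the heart of the offloading change: the block receives its inputs as a mutable list, drops its own local references, parks the pair tensor on the CPU while the outer product mean runs on the MSA tensor, then brings everything back next to the freshly computed update. A hedged, stand-alone sketch of the pattern (the outer product is collapsed to a single channel and device handling is simplified; names here are illustrative, not the real modules):

import torch

def outer_product_mean_sketch(m):
    # stand-in for OuterProductMean, reduced to one channel for brevity
    return torch.einsum("sic,sjc->ij", m, m) / m.shape[0]

def core_step(input_tensors, offload):
    device = input_tensors[0].device
    if offload:
        # z is not needed while the outer product mean runs; park it on the CPU
        # (the real code also calls torch.cuda.empty_cache() at this point)
        input_tensors[1] = input_tensors[1].cpu()
    opm = outer_product_mean_sketch(input_tensors[0])
    if offload:
        # bring z back onto the device holding the freshly computed update
        input_tensors[1] = input_tensors[1].to(device)
    input_tensors[1] = input_tensors[1] + opm
    del opm

m = torch.randn(16, 8, 4)   # [N_seq, N_res, C_m]
z = torch.zeros(8, 8)       # [N_res, N_res], channel dim dropped for the sketch
tensors = [m, z]
core_step(tensors, offload=True)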
         tmu_update = self.tri_mul_out(
             z,
@@ -250,17 +264,30 @@ class EvoformerBlockCore(nn.Module):
             ),
             inplace=inplace_safe,
         )
-        z = add(z,
-            self.ps_dropout_col_layer(
+
+        z = z.transpose(-2, -3)
+        if(inplace_safe):
+            input_tensors[1] = z.contiguous()
+            z = input_tensors[1]
+
+        z = add(z,
+            self.ps_dropout_row_layer(
                 self.tri_att_end(
                     z,
-                    mask=pair_mask,
+                    mask=pair_mask.transpose(-1, -2),
                     chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                 )
             ),
             inplace=inplace_safe,
         )
+
+        z = z.transpose(-2, -3)
+        if(inplace_safe):
+            input_tensors[1] = z.contiguous()
+            z = input_tensors[1]
+
         z = add(z,
             self.pair_transition(
                 z, mask=pair_trans_mask, chunk_size=chunk_size,
@@ -268,6 +295,13 @@ class EvoformerBlockCore(nn.Module):
             inplace=inplace_safe,
         )

+        if(_offload_inference and inplace_safe):
+            device = z.device
+            del m, z
+            input_tensors[0] = input_tensors[0].to(device)
+            input_tensors[1] = input_tensors[1].to(device)
+            m, z = input_tensors
+
         return m, z
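Both triangle-attention submodules can now be the same TriangleAttention because attending over the ending dimension is equivalent to transposing z and the pair mask, applying the starting-node (row-wise) attention, and transposing back, which is exactly what the transpose/contiguous bookkeeping above does while keeping input_tensors[1] pointed at the live storage. A toy single-head check of that equivalence (not the real gated, multi-head, biased TriangleAttention):

import torch

def row_attention(z, mask):
    # single-head attention over the second pair index ("starting node" axis)
    logits = torch.einsum("ijc,ikc->ijk", z, z) / z.shape[-1] ** 0.5
    logits = logits + (mask[:, None, :] - 1.0) * 1e9   # mask[i, k] gates key k in row i
    return torch.einsum("ijk,ikc->ijc", logits.softmax(dim=-1), z)

def col_attention(z, mask):
    # the same attention taken over the first pair index ("ending node" axis)
    logits = torch.einsum("ijc,kjc->ijk", z, z) / z.shape[-1] ** 0.5
    logits = logits + (mask.transpose(-1, -2)[None, :, :] - 1.0) * 1e9
    return torch.einsum("ijk,kjc->ijc", logits.softmax(dim=-1), z)

z = torch.randn(5, 5, 8)
mask = torch.ones(5, 5)
mask[0, 3:] = 0.0

via_transpose = row_attention(
    z.transpose(-2, -3), mask.transpose(-1, -2)
).transpose(-2, -3)
assert torch.allclose(col_attention(z, mask), via_transpose, atol=1e-5)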
@@ -321,23 +355,22 @@ class EvoformerBlock(nn.Module):
         )

     def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+        input_tensors: Sequence[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
+        _offload_inference: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         inplace_safe = not (self.training or torch.is_grad_enabled())

-        print(chunk_size)
-        print(_attn_chunk_size)
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

+        m, z = input_tensors
+
         m = add(m,
             self.msa_dropout_layer(
                 self.msa_att_row(
@@ -359,18 +392,29 @@ class EvoformerBlock(nn.Module):
             ),
             inplace=inplace_safe,
         )

+        if(not inplace_safe):
+            input_tensors = [m, input_tensors[1]]
+
+        del m, z
+
         m, z = self.core(
-            m,
-            z,
+            input_tensors,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
             use_lma=use_lma,
             _mask_trans=_mask_trans,
             _attn_chunk_size=_attn_chunk_size,
+            _offload_inference=_offload_inference,
         )

-        return m, z
+        if(inplace_safe):
+            out = input_tensors
+        else:
+            out = [m, z]
+
+        return out
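EvoformerBlock now follows a list-based calling convention: the inputs travel as a two-element list so that inference code can mutate (and offload) the shared storage without the block holding extra references, while the training path still returns fresh tensors for autograd. A hedged sketch of that convention with hypothetical names:

from typing import List, Sequence
import torch

def toy_block(input_tensors: Sequence[torch.Tensor],
              inplace_safe: bool) -> List[torch.Tensor]:
    m, z = input_tensors
    m_update, z_update = m * 0.5, z * 0.5   # stand-ins for the real updates
    if inplace_safe:
        input_tensors[0] += m_update        # write into the caller's storage
        input_tensors[1] += z_update
        return list(input_tensors)
    return [m + m_update, z + z_update]     # fresh tensors, safe under autograd

tensors = [torch.randn(4, 8, 16), torch.randn(8, 8, 32)]
with torch.no_grad():
    out = toy_block(tensors, inplace_safe=True)
assert out[0] is tensors[0]   # no second copy of the MSA embedding was created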
 class ExtraMSABlock(nn.Module):
@@ -433,19 +477,21 @@ class ExtraMSABlock(nn.Module):
         )

     def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+        input_tensors: Sequence[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
-        _chunk_logits: Optional[int] = 1024,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
+        _offload_inference: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

+        m, z = input_tensors
+
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
         # If function calls could speak...
         m = add(m,
             self.msa_dropout_layer(
@@ -455,44 +501,50 @@ class ExtraMSABlock(nn.Module):
                     mask=msa_mask,
                     chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
-                    use_memory_efficient_kernel=not _chunk_logits and not use_lma,
-                    _chunk_logits=
-                        _chunk_logits if torch.is_grad_enabled() else None,
+                    use_memory_efficient_kernel=not use_lma,
                     _checkpoint_chunks=
                         self.ckpt if torch.is_grad_enabled() else False,
                 )
             ),
-            inplace=not (self.training or torch.is_grad_enabled()),
+            inplace=inplace_safe,
         )

-        def fn(m, z):
-            m = add(m,
+        del m, z
+
+        def fn(input_tensors):
+            m = add(input_tensors[0],
                 self.msa_att_col(
-                    m,
+                    input_tensors[0],
                     mask=msa_mask,
                     chunk_size=chunk_size,
                     use_lma=use_lma,
                 ),
-                inplace=not (self.training or torch.is_grad_enabled()),
+                inplace=inplace_safe,
             )

+            if(not inplace_safe):
+                input_tensors = [m, input_tensors[1]]
+
+            del m
+
             m, z = self.core(
-                m,
-                z,
+                input_tensors,
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
                 _mask_trans=_mask_trans,
-                _attn_chunk_size=_attn_chunk_size
+                _attn_chunk_size=_attn_chunk_size,
+                _offload_inference=_offload_inference,
            )

             return m, z

         if(torch.is_grad_enabled() and self.ckpt):
             checkpoint_fn = get_checkpoint_fn()
-            m, z = checkpoint_fn(fn, m, z)
+            m, z = checkpoint_fn(fn, input_tensors)
         else:
-            m, z = fn(m, z)
+            m, z = fn(input_tensors)

         return m, z
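The ckpt branch still wraps the column attention and the core in activation checkpointing, now passing the list rather than the two tensors. get_checkpoint_fn resolves to a gradient-checkpointing implementation (plain PyTorch or a DeepSpeed equivalent, depending on availability); for reference, the plain-PyTorch pattern this branch maps onto looks roughly like the following, with toy stand-ins for the real modules:

import torch
from torch.utils.checkpoint import checkpoint

def fn(m, z):
    m = m + m.softmax(dim=-1)                                 # stand-in for the column attention
    z = z + torch.einsum("sic,sjc->ijc", m, m) / m.shape[0]   # stand-in for the core
    return m, z

m = torch.randn(4, 8, 16, requires_grad=True)
z = torch.randn(8, 8, 16, requires_grad=True)

# activations inside fn are discarded on the way forward and recomputed in backward
m_out, z_out = checkpoint(fn, m, z, use_reentrant=False)
(m_out.sum() + z_out.sum()).backward()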
@@ -595,37 +647,15 @@ class EvoformerStack(nn.Module):
         if(tune_chunk_size):
             self.chunk_size_tuner = ChunkSizeTuner()

-    def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+    def _forward_list(self,
+        input_tensors: Sequence[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
         use_lma: bool = False,
         _mask_trans: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """
-        Args:
-            m:
-                [*, N_seq, N_res, C_m] MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-            msa_mask:
-                [*, N_seq, N_res] MSA mask
-            pair_mask:
-                [*, N_res, N_res] pair mask
-            chunk_size:
-                Inference-time subbatch size. Acts as a minimum if
-                self.tune_chunk_size is True
-            use_lma: Whether to use low-memory attention during inference
-        Returns:
-            m:
-                [*, N_seq, N_res, C_m] MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-            s:
-                [*, N_res, C_s] single embedding (or None if extra MSA stack)
-        """
+        _offload_inference: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         blocks = [
             partial(
                 b,
@@ -634,6 +664,7 @@ class EvoformerStack(nn.Module):
                 chunk_size=chunk_size,
                 use_lma=use_lma,
                 _mask_trans=_mask_trans,
+                _offload_inference=_offload_inference,
             )
             for b in self.blocks
         ]
@@ -646,9 +677,11 @@ class EvoformerStack(nn.Module):
             blocks = [partial(block_with_cache_clear, b) for b in blocks]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
+            print("evo")
             tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
-                args=(m, z),
+                # We don't want to write in-place during chunk tuning runs
+                args=([t.clone() for t in input_tensors],),
                 min_chunk_size=chunk_size,
             )
             blocks = [
@@ -666,14 +699,54 @@ class EvoformerStack(nn.Module):
         m, z = checkpoint_blocks(
             blocks,
-            args=(m, z),
+            args=input_tensors,
             blocks_per_ckpt=blocks_per_ckpt,
-        )
+        )[0]

         s = self.linear(m[..., 0, :, :])

         return m, z, s
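The clones passed to the chunk size tuner above matter because the tuning trials execute a real block in inference mode, which writes into its inputs in place; without clones the trial runs would corrupt the activations before the actual forward pass. A minimal illustration of the hazard (hypothetical names):

import torch

def representative_block(input_tensors):
    # the inference path mutates its inputs in place, like the inplace_safe branches
    input_tensors[0] += 1.0
    return input_tensors

real = [torch.zeros(3), torch.zeros(3)]
representative_block([t.clone() for t in real])   # timing/tuning trial run
assert real[0].sum() == 0                         # the real activations are untouched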
+    def forward(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool = False,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            m:
+                [*, N_seq, N_res, C_m] MSA embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+            msa_mask:
+                [*, N_seq, N_res] MSA mask
+            pair_mask:
+                [*, N_res, N_res] pair mask
+            chunk_size:
+                Inference-time subbatch size. Acts as a minimum if
+                self.tune_chunk_size is True
+            use_lma: Whether to use low-memory attention during inference
+        Returns:
+            m:
+                [*, N_seq, N_res, C_m] MSA embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+            s:
+                [*, N_res, C_s] single embedding (or None if extra MSA stack)
+        """
+        return self._forward_list(
+            [m, z],
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            _mask_trans=_mask_trans,
+        )
 class ExtraMSAStack(nn.Module):
     """
@@ -730,6 +803,81 @@ class ExtraMSAStack(nn.Module):
         if(tune_chunk_size):
             self.chunk_size_tuner = ChunkSizeTuner()

+    def _prep_blocks(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool,
+        msa_mask: Optional[torch.Tensor],
+        pair_mask: Optional[torch.Tensor],
+        _mask_trans: bool,
+    ):
+        blocks = [
+            partial(
+                b,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                _mask_trans=_mask_trans,
+            ) for b in self.blocks
+        ]
+
+        def clear_cache(b, *args, **kwargs):
+            torch.cuda.empty_cache()
+            return b(*args, **kwargs)
+
+        if(self.clear_cache_between_blocks):
+            blocks = [partial(clear_cache, b) for b in blocks]
+
+        if(chunk_size is not None and self.chunk_size_tuner is not None):
+            print("extra")
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                args=([m.clone(), z.clone()],),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
+                ) for b in blocks
+            ]
+
+        return blocks
+
+    def _forward_list(self,
+        input_tensors: Sequence[torch.Tensor],
+        chunk_size: int,
+        use_lma: bool = False,
+        msa_mask: Optional[torch.Tensor] = None,
+        pair_mask: Optional[torch.Tensor] = None,
+        _mask_trans: bool = True,
+        _offload_inference: bool = False,
+    ) -> torch.Tensor:
+        assert(not self.training)
+
+        blocks = self._prep_blocks(
+            # We are very careful not to create references to these tensors in
+            # this function
+            m=input_tensors[0],
+            z=input_tensors[1],
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            _mask_trans=_mask_trans,
+        )
+
+        for b in blocks:
+            m, z = b(input_tensors, _offload_inference=_offload_inference)
+            input_tensors[0] = m
+            input_tensors[1] = z
+            del m, z
+
+        return input_tensors[1]
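The loop in _forward_list threads the shared list through the blocks, writes the results straight back into it, and deletes the local names, so at most one copy of each representation (plus a block's working set) is alive at a time, and only the pair embedding survives to be returned. A schematic version of the loop:

import torch

def block(input_tensors):
    return input_tensors[0] * 2.0, input_tensors[1] * 2.0   # stand-in updates

input_tensors = [torch.randn(16, 8, 4), torch.randn(8, 8, 4)]
for _ in range(3):
    m, z = block(input_tensors)
    input_tensors[0] = m
    input_tensors[1] = z
    del m, z                 # drop the extra references before the next block
pair = input_tensors[1]      # the extra-MSA stack only returns the pair embedding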
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
@@ -754,60 +902,21 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        if(not self.chunk_msa_attn):
-            checkpoint_fn = get_checkpoint_fn()
-            blocks = [
-                partial(
-                    b,
-                    msa_mask=msa_mask,
-                    pair_mask=pair_mask,
-                    chunk_size=chunk_size,
-                    use_lma=use_lma,
-                    _chunk_logits=None,
-                    _mask_trans=_mask_trans,
-                ) for b in self.blocks
-            ]
-
-            def clear_cache(b, *args, **kwargs):
-                torch.cuda.empty_cache()
-                return b(*args, **kwargs)
-
-            if(self.clear_cache_between_blocks):
-                blocks = [partial(clear_cache, b) for b in blocks]
-
-            if(chunk_size is not None and self.chunk_size_tuner is not None):
-                tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
-                    representative_fn=blocks[0],
-                    args=(m, z),
-                    min_chunk_size=chunk_size,
-                )
-                blocks = [
-                    partial(b,
-                        chunk_size=tuned_chunk_size,
-                        # A temporary measure to address torch's occasional
-                        # inability to allocate large tensors
-                        _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
-                    ) for b in blocks
-                ]
-
-            for b in blocks:
-                if(self.ckpt and torch.is_grad_enabled()):
-                    m, z = checkpoint_fn(b, *(m, z))
-                else:
-                    m, z = b(m, z)
-        else:
-            for b in self.blocks:
-                m, z = b(
-                    m,
-                    z,
-                    msa_mask,
-                    pair_mask,
-                    chunk_size=chunk_size,
-                    use_lma=use_lma,
-                    _mask_trans=_mask_trans
-                )
-
-                if(self.clear_cache_between_blocks):
-                    torch.cuda.empty_cache()
+        checkpoint_fn = get_checkpoint_fn()
+        blocks = self._prep_blocks(
+            m=m,
+            z=z,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            _mask_trans=_mask_trans,
+        )
+
+        for b in blocks:
+            if(self.ckpt and torch.is_grad_enabled()):
+                m, z = checkpoint_fn(b, (m, z))
+            else:
+                m, z = b(m, z)

         return z