Commit fe9ad07e authored by Gustaf Ahdritz
Browse files

Add offloading to evoformer

parent b40fab25
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from typing import Tuple, Sequence, Optional
from functools import partial
from openfold.model.primitives import Linear, LayerNorm
......@@ -29,6 +29,7 @@ from openfold.model.msa import (
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
TriangleAttention,
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
......@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
from openfold.utils.tensor_utils import add
class MSATransition(nn.Module):
......@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.layer_norm(m)
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
......@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
......@@ -155,13 +156,13 @@ class EvoformerBlockCore(nn.Module):
c_hidden_mul,
)
self.tri_att_start = TriangleAttentionStartingNode(
self.tri_att_start = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.tri_att_end = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
......@@ -174,18 +175,16 @@ class EvoformerBlockCore(nn.Module):
)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
def forward(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
......@@ -196,6 +195,8 @@ class EvoformerBlockCore(nn.Module):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
# Need to dodge activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
......@@ -206,13 +207,26 @@ class EvoformerBlockCore(nn.Module):
),
inplace=inplace_safe,
)
z = add(z,
self.outer_product_mean(
if(_offload_inference and inplace_safe):
del m, z
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
),
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
del m, z
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
......@@ -250,17 +264,30 @@ class EvoformerBlockCore(nn.Module):
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.ps_dropout_col_layer(
self.ps_dropout_row_layer(
self.tri_att_end(
z,
mask=pair_mask,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size,
......@@ -268,6 +295,13 @@ class EvoformerBlockCore(nn.Module):
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
device = z.device
del m, z
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
......@@ -321,23 +355,22 @@ class EvoformerBlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
inplace_safe = not (self.training or torch.is_grad_enabled())
print(chunk_size)
print(_attn_chunk_size)
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
......@@ -359,18 +392,29 @@ class EvoformerBlock(nn.Module):
),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
del m, z
m, z = self.core(
m,
z,
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
return m, z
if(inplace_safe):
out = input_tensors
else:
out = [m, z]
return out
class ExtraMSABlock(nn.Module):
......@@ -433,19 +477,21 @@ class ExtraMSABlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
_chunk_logits: Optional[int] = 1024,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
inplace_safe = not (self.training or torch.is_grad_enabled())
# If function calls could speak...
m = add(m,
self.msa_dropout_layer(
......@@ -455,44 +501,50 @@ class ExtraMSABlock(nn.Module):
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not _chunk_logits and not use_lma,
_chunk_logits=
_chunk_logits if torch.is_grad_enabled() else None,
use_memory_efficient_kernel=not use_lma,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
),
inplace=not (self.training or torch.is_grad_enabled()),
inplace=inplace_safe,
)
del m, z

def fn(input_tensors):
    # Column-wise MSA attention followed by the shared Evoformer core.
    # Defined as a closure so the whole tail of the block can be run
    # under a single activation checkpoint.
    m = add(input_tensors[0],
        self.msa_att_col(
            input_tensors[0],
            mask=msa_mask,
            chunk_size=chunk_size,
            use_lma=use_lma,
        ),
        inplace=inplace_safe,
    )

    if(not inplace_safe):
        # BUG FIX: the assignment operator was missing
        # ("input_tensors [m, input_tensors[1]]" indexed the list with a
        # tuple and raised TypeError instead of rebinding it).
        input_tensors = [m, input_tensors[1]]

    # Drop the local reference so the core can free/offload the tensor.
    del m

    m, z = self.core(
        input_tensors,
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        chunk_size=chunk_size,
        use_lma=use_lma,
        _mask_trans=_mask_trans,
        _attn_chunk_size=_attn_chunk_size,
        _offload_inference=_offload_inference,
    )

    return m, z
if(torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
m, z = checkpoint_fn(fn, input_tensors)
else:
m, z = fn(m, z)
m, z = fn(input_tensors)
return m, z
......@@ -595,37 +647,15 @@ class EvoformerStack(nn.Module):
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
def _forward_list(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
blocks = [
partial(
b,
......@@ -634,6 +664,7 @@ class EvoformerStack(nn.Module):
chunk_size=chunk_size,
use_lma=use_lma,
_mask_trans=_mask_trans,
_offload_inference=_offload_inference,
)
for b in self.blocks
]
......@@ -646,9 +677,11 @@ class EvoformerStack(nn.Module):
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
print("evo")
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
# We don't want to write in-place during chunk tuning runs
args=([t.clone() for t in input_tensors],),
min_chunk_size=chunk_size,
)
blocks = [
......@@ -666,14 +699,54 @@ class EvoformerStack(nn.Module):
m, z = checkpoint_blocks(
blocks,
args=(m, z),
args=input_tensors,
blocks_per_ckpt=blocks_per_ckpt,
)
)[0]
s = self.linear(m[..., 0, :, :])
return m, z, s
def forward(self,
    m: torch.Tensor,
    z: torch.Tensor,
    msa_mask: torch.Tensor,
    pair_mask: torch.Tensor,
    chunk_size: int,
    use_lma: bool = False,
    _mask_trans: bool = True,
    _offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Args:
        m:
            [*, N_seq, N_res, C_m] MSA embedding
        z:
            [*, N_res, N_res, C_z] pair embedding
        msa_mask:
            [*, N_seq, N_res] MSA mask
        pair_mask:
            [*, N_res, N_res] pair mask
        chunk_size:
            Inference-time subbatch size. Acts as a minimum if
            self.tune_chunk_size is True
        use_lma: Whether to use low-memory attention during inference
        _offload_inference:
            Whether blocks may offload m/z to CPU between stages during
            inference. New keyword with a False default, so existing
            callers are unaffected. Previously _forward_list supported
            this but forward() never exposed it, making the feature
            unreachable from the public entry point.
    Returns:
        m:
            [*, N_seq, N_res, C_m] MSA embedding
        z:
            [*, N_res, N_res, C_z] pair embedding
        s:
            [*, N_res, C_s] single embedding
    """
    # Wrap the two trunk tensors in a mutable list so _forward_list can
    # rebind (and offload) them between blocks without holding extra
    # references.
    return self._forward_list(
        [m, z],
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        chunk_size=chunk_size,
        use_lma=use_lma,
        _mask_trans=_mask_trans,
        _offload_inference=_offload_inference,
    )
class ExtraMSAStack(nn.Module):
"""
......@@ -730,32 +803,15 @@ class ExtraMSAStack(nn.Module):
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(self,
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
if(not self.chunk_msa_attn):
checkpoint_fn = get_checkpoint_fn()
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
_mask_trans: bool,
):
blocks = [
partial(
b,
......@@ -763,7 +819,6 @@ class ExtraMSAStack(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
_chunk_logits=None,
_mask_trans=_mask_trans,
) for b in self.blocks
]
......@@ -776,9 +831,10 @@ class ExtraMSAStack(nn.Module):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
print("extra")
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
args=([m.clone(), z.clone()],),
min_chunk_size=chunk_size,
)
blocks = [
......@@ -790,24 +846,77 @@ class ExtraMSAStack(nn.Module):
) for b in blocks
]
return blocks
def _forward_list(self,
    input_tensors: Sequence[torch.Tensor],
    chunk_size: int,
    use_lma: bool = False,
    msa_mask: Optional[torch.Tensor] = None,
    pair_mask: Optional[torch.Tensor] = None,
    _mask_trans: bool = True,
    _offload_inference: bool = False,
) -> torch.Tensor:
    """
    Inference-only variant of forward() that threads [m, z] through the
    blocks as a shared mutable list, enabling in-place updates and CPU
    offloading between stages.

    Args:
        input_tensors: two-element list [m, z] (mutated in place)
        chunk_size: inference-time subbatch size
        use_lma: whether to use low-memory attention
        msa_mask: optional [*, N_extra, N_res] MSA mask
        pair_mask: optional [*, N_res, N_res] pair mask
        _offload_inference: whether blocks may offload tensors to CPU
    Returns:
        [*, N_res, N_res, C_z] pair update (input_tensors[1])
    """
    # Offloading mutates input_tensors in place, which is incompatible
    # with autograd bookkeeping.
    assert(not self.training)

    blocks = self._prep_blocks(
        # We are very careful not to create references to these tensors in
        # this function
        m=input_tensors[0],
        z=input_tensors[1],
        chunk_size=chunk_size,
        use_lma=use_lma,
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        _mask_trans=_mask_trans,
    )

    for b in blocks:
        m, z = b(input_tensors, _offload_inference=_offload_inference)
        # Store outputs back into the shared list and drop the local
        # names immediately so stale copies can be freed/offloaded.
        input_tensors[0] = m
        input_tensors[1] = z
        del m, z

    return input_tensors[1]
def forward(self,
    m: torch.Tensor,
    z: torch.Tensor,
    chunk_size: int,
    use_lma: bool = False,
    msa_mask: Optional[torch.Tensor] = None,
    pair_mask: Optional[torch.Tensor] = None,
    _mask_trans: bool = True,
) -> torch.Tensor:
    """
    Args:
        m:
            [*, N_extra, N_res, C_m] extra MSA embedding
        z:
            [*, N_res, N_res, C_z] pair embedding
        chunk_size: Inference-time subbatch size for Evoformer modules
        use_lma: Whether to use low-memory attention during inference
        msa_mask:
            Optional [*, N_extra, N_res] MSA mask
        pair_mask:
            Optional [*, N_res, N_res] pair mask
    Returns:
        [*, N_res, N_res, C_z] pair update
    """
    checkpoint_fn = get_checkpoint_fn()

    blocks = self._prep_blocks(
        m=m,
        z=z,
        chunk_size=chunk_size,
        use_lma=use_lma,
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        _mask_trans=_mask_trans,
    )

    if(self.clear_cache_between_blocks):
        torch.cuda.empty_cache()

    for b in blocks:
        # Blocks now take a single [m, z] sequence; msa_mask/pair_mask
        # are already bound as keywords by _prep_blocks, so positional
        # m, z would collide with them. A fresh list per call also lets
        # inference-mode blocks item-assign it safely.
        if(self.ckpt and torch.is_grad_enabled()):
            m, z = checkpoint_fn(b, [m, z])
        else:
            m, z = b([m, z])

    return z
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment