Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
......@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import torch
import torch.nn as nn
from typing import Tuple, Optional
from typing import Tuple, Sequence, Optional
from functools import partial
from openfold.model.primitives import Linear, LayerNorm
......@@ -29,6 +29,7 @@ from openfold.model.msa import (
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
TriangleAttention,
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
......@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
from openfold.utils.tensor_utils import add
class MSATransition(nn.Module):
......@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.layer_norm(m)
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
......@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
......@@ -140,13 +141,13 @@ class PairStack(nn.Module):
c_hidden_mul,
)
self.tri_att_start = TriangleAttentionStartingNode(
self.tri_att_start = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.tri_att_end = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
......@@ -159,32 +160,109 @@ class PairStack(nn.Module):
)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(
self,
z: torch.Tensor,
def forward(self,
input_tensors: Sequence[torch.Tensor],
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_offload_inference: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
pair_trans_mask = pair_mask if _mask_trans else None
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
if (_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
z = z + self.ps_dropout_col_layer(
self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
z,
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
z = z + self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
del tmu_update
z = add(z,
self.ps_dropout_row_layer(
self.tri_att_start(
z,
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if (inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.ps_dropout_row_layer(
self.tri_att_end(
z,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if (inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
return z
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
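# Illustrative sketch (not part of this diff): the add(..., inplace=...) helper
# imported from openfold.utils.tensor_utils and used for the residual updates
# above lets inference reuse existing storage while staying out-of-place when
# gradients are recorded. The name add_maybe_inplace and this body are only a
# guess at its shape, not the real implementation.
import torch

def add_maybe_inplace(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    if inplace:
        m1 += m2       # write into m1's existing storage (inference only)
    else:
        m1 = m1 + m2   # allocate a new tensor (required under autograd)
    return m1

# Usage mirroring the PairStack residual updates above:
z = torch.zeros(4, 4, 8)
update = torch.randn(4, 4, 8)
z = add_maybe_inplace(z, update, inplace=not torch.is_grad_enabled())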
class EvoformerBlock(nn.Module):
......@@ -248,41 +326,134 @@ class EvoformerBlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
if self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
),
inplace=inplace_safe,
)
m = add(
m,
self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if not self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
z = self.pair_stack(
z,
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
elif (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
device = input_tensors[0].device
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
if (not inplace_safe):
input_tensors = [m, z]
del m, z
m, z = self.pair_stack(
input_tensors,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size
)
return m, z
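# Illustrative sketch (not part of this diff): the offload branches above delete
# local aliases and assert sys.getrefcount(tensor) == 2 before calling .cpu(),
# because a tensor's GPU storage is only released once no Python reference to it
# remains. The count of 2 is the list's reference plus getrefcount's own
# temporary argument reference. offload_pair_to_cpu is a made-up name for the
# pattern, not an OpenFold function.
import sys
import torch

def offload_pair_to_cpu(input_tensors: list) -> None:
    for i in range(len(input_tensors)):
        assert sys.getrefcount(input_tensors[i]) == 2
        input_tensors[i] = input_tensors[i].cpu()

tensors = [torch.zeros(8, 8), torch.zeros(8, 8)]
offload_pair_to_cpu(tensors)  # safe: no other names alias these tensors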
......@@ -358,63 +529,140 @@ class ExtraMSABlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]:
def add(m1, m2):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if(torch.is_grad_enabled()):
m1 = m1 + m2
else:
m1 += m2
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
return m1
m, z = input_tensors
if self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
m = add(m, self.msa_dropout_layer(
self.msa_att_row(
m.clone() if torch.is_grad_enabled() else m,
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=chunk_size,
use_memory_efficient_kernel=not _chunk_logits,
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m.clone() if torch.is_grad_enabled() else m,
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not use_lma and m.is_cuda,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
),
inplace=inplace_safe,
)
if (not inplace_safe):
input_tensors = [m, z]
del m, z
def fn(input_tensors):
m, z = input_tensors
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
),
inplace=inplace_safe,
)
m = add(
m,
self.msa_transition(
m, mask=msa_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
))
def fn(m, z):
m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
m = add(m, self.msa_transition(
m, mask=msa_mask, chunk_size=chunk_size
))
if not self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
z = self.pair_stack(
z, pair_mask=pair_mask, chunk_size=chunk_size
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
if (not inplace_safe):
input_tensors = [m, z]
del m, z
m, z = self.pair_stack(
input_tensors,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size
)
return m, z
if(torch.is_grad_enabled() and self.ckpt):
if (torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
m, z = checkpoint_fn(fn, input_tensors)
else:
m, z = fn(m, z)
m, z = fn(input_tensors)
return m, z
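# Illustrative sketch (not part of this diff): when gradients are enabled and
# self.ckpt is set, the fn(...) closure above is re-run under activation
# checkpointing, so its intermediates are recomputed in the backward pass
# instead of stored. A minimal stand-in using torch.utils.checkpoint; the real
# get_checkpoint_fn may return a DeepSpeed variant instead, and heavy_update is
# just a placeholder computation.
import torch
from torch.utils.checkpoint import checkpoint

def heavy_update(m, z):
    m = m + torch.softmax(m, dim=-1)
    z = z + z.mean(dim=-1, keepdim=True)
    return m, z

m = torch.randn(2, 8, 16, requires_grad=True)
z = torch.randn(8, 8, 16, requires_grad=True)
if torch.is_grad_enabled():
    m, z = checkpoint(heavy_update, m, z, use_reentrant=False)
else:
    m, z = heavy_update(m, z)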
......@@ -446,6 +694,7 @@ class EvoformerStack(nn.Module):
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
"""
......@@ -482,6 +731,8 @@ class EvoformerStack(nn.Module):
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
"""
super(EvoformerStack, self).__init__()
......@@ -511,14 +762,114 @@ class EvoformerStack(nn.Module):
self.linear = Linear(c_m, c_s)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache()
return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# We don't want to write in-place during chunk tuning runs
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
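# Illustrative sketch (not part of this diff): ChunkSizeTuner is assumed to
# probe a representative block with increasing chunk sizes and keep the largest
# one that runs without exhausting memory, with min_chunk_size as the floor.
# This toy version only conveys that idea; it is not the real tuner's logic.
def tune_chunk_size_sketch(representative_fn, args, min_chunk_size, max_chunk_size=512):
    best = min_chunk_size
    candidate = min_chunk_size
    while candidate <= max_chunk_size:
        try:
            representative_fn(*args, chunk_size=candidate)
        except RuntimeError:
            # CUDA out-of-memory errors surface as RuntimeError subclasses
            break
        best = candidate
        candidate *= 2
    return best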
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
m, z = input_tensors
s = self.linear(m[..., 0, :, :])
return m, z, s
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
m:
......@@ -529,6 +880,13 @@ class EvoformerStack(nn.Module):
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
......@@ -536,33 +894,31 @@ class EvoformerStack(nn.Module):
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
"""
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
blocks_per_ckpt = None
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
blocks_per_ckpt=blocks_per_ckpt,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
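# Illustrative sketch (not part of this diff): checkpoint_blocks above is
# assumed to apply the prepared blocks in sequence, wrapping groups of
# blocks_per_ckpt of them in an activation checkpoint when gradients are
# enabled (blocks_per_ckpt=None, as in the inference path above, disables
# checkpointing). run_blocks is a rough stand-in, not the real utility.
from torch.utils.checkpoint import checkpoint

def run_blocks(blocks, args, blocks_per_ckpt=None):
    if blocks_per_ckpt is None:
        for block in blocks:
            args = block(*args)
        return args
    for i in range(0, len(blocks), blocks_per_ckpt):
        group = blocks[i:i + blocks_per_ckpt]
        def run_group(*a, _group=group):
            for block in _group:
                a = block(*a)
            return a
        args = checkpoint(run_group, *args, use_reentrant=False)
    return args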
......@@ -570,7 +926,6 @@ class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
......@@ -589,14 +944,13 @@ class ExtraMSAStack(nn.Module):
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.ckpt = ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.chunk_msa_attn = chunk_msa_attn
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
......@@ -614,16 +968,107 @@ class ExtraMSAStack(nn.Module):
opm_first=opm_first,
inf=inf,
eps=eps,
ckpt=ckpt if chunk_msa_attn else False,
ckpt=False,
)
self.blocks.append(block)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
) for b in self.blocks
]
def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache()
return b(*args, **kwargs)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
return input_tensors[1]
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
......@@ -632,6 +1077,8 @@ class ExtraMSAStack(nn.Module):
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
......@@ -639,35 +1086,22 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
if(not self.chunk_msa_attn):
checkpoint_fn = get_checkpoint_fn()
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_chunk_logits=None
) for b in self.blocks
]
def clear_cache(b, *args):
torch.cuda.empty_cache()
return b(*args)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, *(m, z))
else:
m, z = b(m, z)
else:
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
checkpoint_fn = get_checkpoint_fn()
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, m, z)
else:
m, z = b(m, z)
return z
......@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm,
compute_predicted_aligned_error,
)
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
......@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z]
def _forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
......@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
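# Illustrative sketch (not part of this diff): the wrapper above upcasts the
# input and disables autocast so this head runs in full precision even while
# the surrounding model executes under fp16 autocast. The head below is a
# stand-in layer, not the DistogramHead itself.
import torch
import torch.nn as nn

head = nn.Linear(128, 64)

def fp32_head(z: torch.Tensor) -> torch.Tensor:
    with torch.cuda.amp.autocast(enabled=False):
        return head(z.float())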
class TMScoreHead(nn.Module):
......
......@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import weakref
import torch
import torch.nn as nn
......@@ -34,12 +35,26 @@ from openfold.model.embedders import (
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import (
compute_plddt,
)
from openfold.utils.tensor_utils import (
add,
dict_multimap,
tensor_tree_map,
)
......@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module
if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
**self.config["input_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
**self.config["recycling_embedder"],
)
if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer(
template_config,
if (self.template_config.enabled):
if(self.globals.is_multimer):
self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
else:
self.template_embedder = TemplateEmbedder(
self.template_config,
)
if (self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
else:
self.template_embedder = TemplateEmbedder(
template_config,
self.extra_msa_stack = ExtraMSAStack(
**self.extra_msa_config["extra_msa_stack"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
**self.config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**config["structure_module"],
**self.config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
self.config["heads"],
)
self.config = config
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
return template_embeds
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
outputs = {}
......@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
## Initialize the MSA and pair representations
# Initialize the MSA and pair representations
if (self.globals.is_multimer):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
else:
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Initialize the recycling embeddings, if needs be
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
......@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
inplace_safe=inplace_safe,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
z = add(z, z_prev_emb, inplace=inplace_safe)
# Possibly prevents memory fragmentation
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
if(self.globals.is_multimer):
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size
)
template_embeds = self.embed_templates(
template_feats,
feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
......@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z]
z = self.extra_msa_stack(
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
del z
# Predict 3D structure
outputs["sm"] = self.structure_module(
s,
z,
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
z_prev = outputs["pair"]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
......@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
......@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
......@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
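# Illustrative sketch (not part of this diff): the recycling loop above only
# tracks gradients on the final cycle, so earlier cycles run as plain inference
# and their activations can be dropped. Reduced to its core, with iteration_fn
# as a hypothetical stand-in for one model iteration:
import torch

def run_with_recycling(iteration_fn, num_iters):
    is_grad_enabled = torch.is_grad_enabled()
    prevs = None
    outputs = None
    for cycle_no in range(num_iters):
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            outputs, prevs = iteration_fn(prevs)
    return outputs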
......
......@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
......@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
use_memory_efficient_kernel: bool,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor:
mha = partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel
)
def fn(m, biases, flash_mask):
m = self.layer_norm_m(m)
return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
)
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer(
mha,
{
"q_x": m,
"kv_x": m,
"biases": biases,
},
fn,
inputs,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2])
)
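# Illustrative sketch (not part of this diff): chunk_layer above is assumed to
# slice its tensor inputs along the flattened batch dimensions, apply the
# wrapped function to each slice, and concatenate the results, bounding peak
# memory. A toy single-dimension version of the same idea:
import torch

def chunked_apply(fn, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    outs = [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    return torch.cat(outs, dim=0)

x = torch.randn(10, 4)
assert torch.allclose(chunked_apply(torch.relu, x, chunk_size=4), torch.relu(x))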
......@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
......@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
......@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
......@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
......@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
......@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z)
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else:
m = self.layer_norm_m(m)
m = self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
return m
......@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if mask is not None:
mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer(
self.global_attention,
fn,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
......@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
......@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
#m = self.layer_norm_m(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
m = self.global_attention(m=m, mask=mask)
m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......
......@@ -20,7 +20,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module):
......@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def forward(self,
def _forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
ln = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
b = self.linear_2(m) * mask
a = self.linear_1(ln)
a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
......@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm)
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
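# Illustrative sketch (not part of this diff): the heart of _opm is an outer
# product between the two projected MSA activations, summed over the sequence
# dimension and normalized by the number of unmasked rows. Stripped of
# batching, chunking, and the final linear projection to C_z:
import torch

n_seq, n_res, c = 5, 7, 3
a = torch.randn(n_seq, n_res, c)
b = torch.randn(n_seq, n_res, c)
mask = torch.ones(n_seq, n_res, 1)

outer = torch.einsum("sic,sjd->ijcd", a * mask, b * mask)  # [N_res, N_res, C, C]
norm = torch.einsum("sic,sjc->ijc", mask, mask)            # [N_res, N_res, 1]
outer = outer / (norm.unsqueeze(-1) + 1e-3)                # eps mirrors self.eps above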
......@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module):
......@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
z = self.linear_2(z)
z = z * mask
return z
......@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
......
......@@ -13,24 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)
DEFAULT_LMA_Q_CHUNK_SIZE=1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096
def _prod(nums):
out = 1
for n in nums:
......@@ -145,26 +160,26 @@ class Linear(nn.Linear):
with torch.no_grad():
self.bias.fill_(0)
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
with torch.no_grad():
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
raise ValueError("Invalid init string.")
if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
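# Illustrative sketch (not part of this diff): typical use of the init strings
# handled above, as read from the code rather than stated by its authors.
# "final" zero-initializes the weight so a fresh residual branch starts as a
# no-op; "gating" also sets the bias to 1, biasing a sigmoid gate toward open.
proj_gate = Linear(64, 64, init="gating")
proj_out = Linear(64, 128, init="final")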
class LayerNorm(nn.Module):
......@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
def forward(self, x):
d = x.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
......@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
type bfloat16
"""
d = t.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
......@@ -403,8 +426,10 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
......@@ -423,29 +448,41 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
q_chunk_size:
lma_q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
lma_kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
"If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided"
)
if(use_memory_efficient_kernel and use_lma):
if(use_flash and biases is not None):
raise ValueError(
"Choose one of use_memory_efficient_kernel and use_lma"
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
raise ValueError(
"Choose at most one alternative attention algorithm"
)
if(biases is None):
biases = []
# [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x)
# [*, Q, H, C_hidden]
if is_fp16_enabled():
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
raise ValueError(
......@@ -459,7 +496,10 @@ class Attention(nn.Module):
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif(use_flash):
o = _flash_attn(q, k, v, flash_mask)
else:
o = _attention(q, k, v, biases)
o = o.transpose(-2, -3)
......@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
def forward(self,
m: torch.Tensor,
mask: torch.Tensor,
use_lma: bool = False,
) -> torch.Tensor:
# [*, N_res, C_in]
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps
......@@ -511,20 +555,30 @@ class GlobalAttention(nn.Module):
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = softmax_no_cast(a)
if(not use_lma):
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
a += bias
a = softmax_no_cast(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
else:
o = _lma(
q,
k,
v,
[bias],
DEFAULT_LMA_Q_CHUNK_SIZE,
DEFAULT_LMA_KV_CHUNK_SIZE
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
......@@ -552,12 +606,12 @@ def _lma(
q_chunk_size: int,
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-3], k.shape[-3]
no_q, no_kv = q.shape[-2], k.shape[-2]
# [*, Q, H, C_hidden]
# [*, H, Q, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
q_chunk = q[..., q_s: q_s + q_chunk_size, :]
large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
......@@ -566,24 +620,22 @@ def _lma(
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk,
"...hqd,...hkd->...hqk", q_chunk, k_chunk,
)
for b in small_bias_chunks:
a += b
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
......@@ -595,14 +647,80 @@ def _lma(
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs
chunk_values = chunk_values * max_diffs.unsqueeze(-1)
chunk_weights = chunk_weights * max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
return o
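# Illustrative sketch (not part of this diff): the chunk bookkeeping in _lma
# above relies on softmax being recoverable from per-chunk maxima and partial
# exponential sums, since exp(x - m_chunk) * exp(m_chunk - m_global) equals
# exp(x - m_global). A one-dimensional check of that identity:
import torch

x = torch.randn(6)
chunks = torch.split(x, 2)
maxes = [c.max() for c in chunks]
global_max = torch.stack(maxes).max()
num = torch.cat([
    torch.exp(c - m) * torch.exp(m - global_max) for c, m in zip(chunks, maxes)
])
probs = num / num.sum()
assert torch.allclose(probs, torch.softmax(x, dim=-1))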
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed):
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype
q = q.half()
k = k.half()
v = v.half()
kv_mask = kv_mask.half()
# [*, B, N, H, C]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# [B_flat, N, H, C]
q = q.reshape(-1, *q.shape[-3:])
k = k.reshape(-1, *k.shape[-3:])
v = v.reshape(-1, *v.shape[-3:])
# Flattened batch size
batch_size = q.shape[0]
# [B_flat * N, H, C]
q = q.reshape(-1, *q.shape[-2:])
q_max_s = n
q_cu_seqlens = torch.arange(
0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
)
# [B_flat, N, 2, H, C]
kv = torch.stack([k, v], dim=-3)
kv_shape = kv.shape
# [B_flat, N, 2 * H * C]
kv = kv.reshape(*kv.shape[:-3], -1)
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p = 0.,
softmax_scale = 1., # q has been scaled already
)
# [*, B, N, H, C]
out = out.reshape(*batch_dims, n, no_heads, c)
out = out.to(dtype=dtype)
return out
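# Illustrative sketch (not part of this diff): unpad_input and the cu_seqlens
# tensors used above pack variable-length sequences by dropping padded
# positions and recording cumulative sequence lengths. The underlying idea in
# plain PyTorch (dtypes and exact layout differ from the FlashAttention API):
import torch

kv = torch.randn(2, 5, 8)                       # [batch, seq, features]
kv_mask = torch.tensor([[1, 1, 1, 0, 0],
                        [1, 1, 1, 1, 0]], dtype=torch.bool)
kv_unpad = kv[kv_mask]                          # [total_valid_tokens, features]
seqlens = kv_mask.sum(dim=-1)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0), (1, 0))
# cu_seqlens == tensor([0, 3, 7]); sequence i spans kv_unpad[cu_seqlens[i]:cu_seqlens[i + 1]]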
......@@ -12,11 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import math
import sys
from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Sequence, Union
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -27,11 +31,12 @@ from openfold.np.residue_constants import (
)
from openfold.utils.geometry.quat_rigid import QuatRigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.geometry.vector import Vec3Array, square_euclidean_distance
from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
dict_multimap,
......@@ -39,6 +44,8 @@ from openfold.utils.tensor_utils import (
flatten_final_dims,
)
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
......@@ -164,6 +171,7 @@ class PointProjection(nn.Module):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.num_points = num_points
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
......@@ -173,22 +181,30 @@ class PointProjection(nn.Module):
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
points_local = points_local.reshape(
*points_local.shape[:-1],
self.no_heads,
-1,
)
if isinstance(rigids, Rigid3Array):
points_local = points_local.reshape(
*points_local.shape[:-1],
self.no_heads,
-1,
)
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
points_local = torch.stack(points_local, dim=-1)
if not isinstance(rigids, Rigid3Array):
points_local = points_local.reshape(
*points_local.shape[:-2], self.no_heads, -1, 3
)
points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points):
return points_global, points_local
return points_global
return points_global
class InvariantPointAttention(nn.Module):
......@@ -242,8 +258,8 @@ class InvariantPointAttention(nn.Module):
self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.c_s,
self.no_qk_points,
self.no_heads
)
......@@ -288,6 +304,9 @@ class InvariantPointAttention(nn.Module):
z: torch.Tensor,
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
......@@ -302,6 +321,11 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
if (_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
#######################################
# Generate scalar and point activations
#######################################
......@@ -312,7 +336,7 @@ class InvariantPointAttention(nn.Module):
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, P_qk]
q_pts = self.linear_q_points(s, r)
q_pts = self.linear_q_points(s, r)
# The following two blocks are equivalent
# They're separated only to preserve compatibility with old AF weights
......@@ -351,13 +375,25 @@ class InvariantPointAttention(nn.Module):
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z)
b = self.linear_b(z[0])
if (_offload_inference):
assert (sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
)
else:
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
......@@ -369,7 +405,12 @@ class InvariantPointAttention(nn.Module):
pt_att = sum([c ** 2 for c in pt_att])
else:
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
if (inplace_safe):
pt_att *= pt_att
else:
pt_att = pt_att ** 2
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
......@@ -378,7 +419,11 @@ class InvariantPointAttention(nn.Module):
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
pt_att = pt_att * head_weights
if (inplace_safe):
pt_att *= head_weights
else:
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
......@@ -388,9 +433,21 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
if (inplace_safe):
a += pt_att
del pt_att
a += square_mask.unsqueeze(-3)
# in-place softmax
attn_core_inplace_cuda.forward_(
a,
reduce(mul, a.shape[:-1]),
a.shape[-1],
)
else:
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
################
# Compute output
......@@ -419,13 +476,22 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, H, 3, N_res, P_v]
if (inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
for v in torch.unbind(v_pts, dim=-3)
]
o_pt = torch.stack(o_pt, dim=-3)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
......@@ -440,8 +506,11 @@ class InvariantPointAttention(nn.Module):
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
......@@ -450,7 +519,7 @@ class InvariantPointAttention(nn.Module):
s = self.linear_out(
torch.cat(
(o, *o_pt, o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
).to(dtype=z[0].dtype)
)
return s
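For reference, the fp16 branch above disables autocast around the q/k matmul so the attention logits accumulate in fp32. A minimal, self-contained sketch of that idiom, independent of this module (all names here are illustrative, not part of the diff):

import torch

def fp32_logits(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Illustrative helper: when autocast has downcast q and k to fp16,
    # temporarily disable it and upcast so the matmul (the numerically
    # sensitive step) runs in fp32; the caller decides whether to cast back.
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return torch.matmul(q.float(), k.transpose(-1, -2).float())
    return torch.matmul(q, k.transpose(-1, -2))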
......@@ -611,11 +680,11 @@ class StructureModule(nn.Module):
self.inf = inf
self.is_multimer = is_multimer
# To be lazily initialized later
self.default_frames = None
self.group_idx = None
self.atom_mask = None
self.lit_positions = None
# Buffers to be lazily initialized later
# self.default_frames
# self.group_idx
# self.atom_mask
# self.lit_positions
self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = LayerNorm(self.c_z)
......@@ -655,62 +724,32 @@ class StructureModule(nn.Module):
self.no_angles,
self.epsilon,
)
def _init_residue_constants(self, float_dtype, device):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.dtype, r.device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
def _forward_monomer(self,
s,
z,
def _forward_monomer(
self,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
"""
Args:
evoformer_output_dict:
Dictionary containing:
"single":
[*, N_res, C_s] single representation
"pair":
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
s = evoformer_output_dict["single"]
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
......@@ -719,7 +758,14 @@ class StructureModule(nn.Module):
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(z)
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if (_offload_inference):
assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
# [*, N, C_s]
s_initial = s
......@@ -736,11 +782,19 @@ class StructureModule(nn.Module):
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
s = s + self.ipa(
s,
z,
rigids,
mask,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list
)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
rigids = rigids.compose_q_update_vec(self.bb_update(s))
......@@ -781,24 +835,35 @@ class StructureModule(nn.Module):
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
"states": s,
}
outputs.append(preds)
if i < (self.no_blocks - 1):
rigids = rigids.stop_rot_gradient()
rigids = rigids.stop_rot_gradient()
del z, z_reference_list
if (_offload_inference):
evoformer_output_dict["pair"] = (
evoformer_output_dict["pair"].to(s.device)
)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
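A hedged sketch of the calling convention the offload branch above implies: the pair tensor must live only inside the evoformer output dict, because _forward_monomer asserts getrefcount == 2 before parking it on the CPU (variable names are illustrative, not from the diff):

# Structure-module call with CPU offloading of the pair representation.
evoformer_outputs = {"single": s, "pair": z}
del z  # the dict slot must be the only live reference to the pair tensor
with torch.no_grad():
    sm_out = structure_module(
        evoformer_outputs,
        aatype,
        mask=seq_mask,
        inplace_safe=True,
        _offload_inference=True,
    )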
def _forward_multimer(self,
s,
z,
aatype,
mask=None,
def _forward_multimer(
self,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
s = evoformer_output_dict["single"]
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
......@@ -807,7 +872,14 @@ class StructureModule(nn.Module):
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(z)
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if (_offload_inference):
assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
# [*, N, C_s]
s_initial = s
......@@ -821,7 +893,15 @@ class StructureModule(nn.Module):
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
s = s + self.ipa(
s,
z,
rigids,
mask,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list
)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
......@@ -848,13 +928,19 @@ class StructureModule(nn.Module):
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz.to_tensor(),
"positions": pred_xyz,
}
outputs.append(preds)
if i < (self.no_blocks - 1):
rigids = rigids.stop_rot_gradient()
rigids = rigids.stop_rot_gradient()
del z, z_reference_list
if (_offload_inference):
evoformer_output_dict["pair"] = (
evoformer_output_dict["pair"].to(s.device)
)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
......@@ -863,10 +949,11 @@ class StructureModule(nn.Module):
def forward(
self,
s,
z,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
"""
Args:
......@@ -882,8 +969,73 @@ class StructureModule(nn.Module):
A dictionary of outputs
"""
if(self.is_multimer):
outputs = self._forward_multimer(s, z, aatype, mask)
outputs = self._forward_multimer(evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference)
else:
outputs = self._forward_monomer(s, z, aatype, mask)
outputs = self._forward_monomer(evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference)
return outputs
def _init_residue_constants(self, float_dtype, device):
if not hasattr(self, "default_frames"):
self.register_buffer(
"default_frames",
torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "group_idx"):
self.register_buffer(
"group_idx",
torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "atom_mask"):
self.register_buffer(
"atom_mask",
torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "lit_positions"):
self.register_buffer(
"lit_positions",
torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.dtype, r.device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
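The lazy register_buffer pattern above replaces the old plain-attribute constants so the tensors follow the module across .to(device)/.half() calls without being written to checkpoints. A minimal standalone sketch of the same idiom (class and buffer names are made up):

import torch
import torch.nn as nn

class LazyConstants(nn.Module):
    def _init_constants(self, dtype, device):
        # Register on first use; persistent=False keeps it out of state_dict.
        if not hasattr(self, "lookup_table"):
            self.register_buffer(
                "lookup_table",
                torch.zeros(21, 14, dtype=dtype, device=device),
                persistent=False,
            )

    def forward(self, x):
        self._init_constants(x.dtype, x.device)
        return x + self.lookup_table.sum()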
......@@ -14,6 +14,7 @@
# limitations under the License.
from functools import partial
import math
import sys
from typing import Optional, List
import torch
......@@ -34,10 +35,19 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.tensor_utils import (
from openfold.utils.chunk_utils import (
chunk_layer,
ChunkSizeTuner,
)
from openfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.utils.tensor_utils import (
add,
permute_final_dims,
flatten_final_dims,
tensor_tree_map,
)
......@@ -77,6 +87,7 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
......@@ -84,7 +95,7 @@ class TemplatePointwiseAttention(nn.Module):
"biases": biases,
}
return chunk_layer(
self.mha,
partial(self.mha, use_lma=use_lma),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
......@@ -95,7 +106,9 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
# This module suffers greatly from a small chunk size
chunk_size: Optional[int] = 256,
use_lma: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -121,10 +134,10 @@ class TemplatePointwiseAttention(nn.Module):
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
if chunk_size is not None and not self.training:
z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
else:
z = self.mha(q_x=z, kv_x=t, biases=biases)
z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
......@@ -186,74 +199,118 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n,
)
def tri_att_start_end(self, single, single_mask, chunk_size):
single = single + self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
def tri_att_start_end(self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe):
single = add(single,
self.dropout_row(
self.tri_att_start(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
single = add(single,
self.dropout_col(
self.tri_att_end(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
return single
def tri_mul_out_in(self, single, single_mask):
single = single + self.dropout_row(
self.tri_mul_out(
single,
mask=single_mask
)
def tri_mul_out_in(self, single, single_mask, inplace_safe):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
single = single + self.dropout_row(
self.tri_mul_in(
single,
mask=single_mask
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
return single
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
if self.tri_mul_first:
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
single_mask=single_mask),
single_mask=single_mask,
inplace_safe=inplace_safe),
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
chunk_size=chunk_size)
use_lma=use_lma,
inplace_safe=inplace_safe)
else:
single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single,
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
chunk_size=chunk_size),
single_mask=single_mask)
single = single + self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
)
single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
use_lma=use_lma,
inplace_safe=inplace_safe),
single_mask=single_mask,
inplace_safe=inplace_safe)
single = add(single,
self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
),
inplace_safe,
)
if (not inplace_safe):
single_templates[i] = single
if (not inplace_safe):
z = torch.cat(single_templates, dim=-4)
return z
......@@ -273,6 +330,7 @@ class TemplatePairStack(nn.Module):
dropout_rate,
tri_mul_first,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
**kwargs,
):
......@@ -314,11 +372,18 @@ class TemplatePairStack(nn.Module):
self.layer_norm = LayerNorm(c_t)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
):
"""
......@@ -335,16 +400,34 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
blocks = [
partial(
b,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
t, = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
blocks=blocks,
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
......@@ -352,3 +435,223 @@ class TemplatePairStack(nn.Module):
t = self.layer_norm(t)
return t
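A hedged usage sketch of the tuning path above: with tune_chunk_size=True the stack probes a representative block at inference time and may raise the effective chunk size, so it must be run in eval mode. Construction of the stack and the input tensors is assumed elsewhere; argument values are illustrative:

stack.eval()  # the tuner asserts not self.training
with torch.no_grad():
    t = stack(
        t,              # [*, N_templ, N_res, N_res, C_t] template pair activations
        mask,           # [*, N_templ, N_res, N_res]
        chunk_size=4,   # lower bound handed to the ChunkSizeTuner
        use_lma=False,
        inplace_safe=True,
        _mask_trans=False,
    )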
def embed_templates_offload(
model,
batch,
z,
pair_mask,
templ_dim,
template_chunk_size=256,
inplace_safe=False,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
template_chunk_size:
Integer value controlling how quickly the offloaded pair embedding
tensor is brought back into GPU memory. In dire straits, can be
lowered to reduce memory consumption of this function even more.
Returns:
A dictionary of template pair and angle embeddings.
A version of the "embed_templates" method of the AlphaFold class that
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
).to(z.dtype)
t = model.template_pair_embedder(t)
# [*, 1, N, N, C_z]
t = model.template_pair_stack(
t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_lma=model.globals.use_lma,
_mask_trans=model.config._mask_trans,
)
assert(sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu())
del t
# Preallocate the output tensor
t = z.new_zeros(z.shape)
for i in range(0, n, template_chunk_size):
pair_chunks = [
p[..., i: i + template_chunk_size, :, :] for p in pair_embeds_cpu
]
pair_chunk = torch.cat(pair_chunks, dim=templ_dim).to(device=z.device)
z_chunk = z[..., i: i + template_chunk_size, :, :]
att_chunk = model.template_pointwise_att(
pair_chunk,
z_chunk,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=model.globals.use_lma,
)
t[..., i: i + template_chunk_size, :, :] = att_chunk
del pair_chunks
if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
if model.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": t})
return ret
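A hedged sketch of how embed_templates_offload might be wired into a model's template path. The model, batch, z, and pair_mask are assumed to exist with the shapes documented above, the templ_dim value is illustrative, and folding the result back into z mirrors the non-offloaded embed_templates:

template_embeds = embed_templates_offload(
    model,
    batch,                    # AlphaFold feature dict with template_* entries
    z,                        # [*, N, N, C_z] pair embedding
    pair_mask,                # [*, N, N]
    templ_dim=0,              # index of the template dimension in the batch tensors
    template_chunk_size=128,  # lower to trade speed for GPU memory
)
z = z + template_embeds["template_pair_embedding"]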
def embed_templates_average(
model,
batch,
z,
pair_mask,
templ_dim,
templ_group_size=2,
inplace_safe=False,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
templ_group_size:
Granularity of the approximation. Larger values trade memory for
greater proximity to the original function
Returns:
A dictionary of template pair and angle embeddings.
A memory-efficient approximation of the "embed_templates" method of the
AlphaFold class. Instead of running pointwise attention over pair
embeddings for all of the templates at the same time, it splits templates
into groups of size templ_group_size, computes embeddings for each group
normally, and then averages the group embeddings. In our experiments, this
approximation has a minimal effect on the quality of the resulting
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
# Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
out_tensor = z.new_zeros(z.shape)
for i in range(0, n_templ, templ_group_size):
def slice_template_tensor(t):
s = [slice(None) for _ in t.shape]
s[templ_dim] = slice(i, i + templ_group_size)
return t[s]
template_feats = tensor_tree_map(
slice_template_tensor,
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
).to(z.dtype)
# [*, S_t, N, N, C_z]
t = model.template_pair_embedder(t)
t = model.template_pair_stack(
t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_lma=model.globals.use_lma,
_mask_trans=model.config._mask_trans,
)
t = model.template_pointwise_att(
t,
z,
template_mask=template_feats["template_mask"].to(dtype=z.dtype),
use_lma=model.globals.use_lma,
)
denom = math.ceil(n_templ / templ_group_size)
if(inplace_safe):
t /= denom
else:
t = t / denom
if(inplace_safe):
out_tensor += t
else:
out_tensor = out_tensor + t
del t
if(inplace_safe):
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
if model.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": out_tensor})
return ret
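The averaging trick keeps peak memory roughly constant in the number of templates at the cost of a small approximation error, since each group's pointwise attention is normalized over that group only. A hedged usage sketch, under the same assumptions as the previous example:

template_embeds = embed_templates_average(
    model,
    batch,
    z,
    pair_mask,
    templ_dim=0,          # illustrative, as above
    templ_group_size=2,   # larger groups track the exact computation more closely
    inplace_safe=False,
)
z = z + template_embeds["template_pair_embedding"]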
......@@ -21,8 +21,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
......@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
class TriangleAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, starting, inf=1e9
self, c_in, c_hidden, no_heads, starting=True, inf=1e9
):
"""
Args:
......@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
partial(self.mha),
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -86,15 +99,14 @@ class TriangleAttention(nn.Module):
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
"""
if mask is None:
# [*, I, J]
mask = x.new_ones(
x.shape[:-1],
)
# Shape annotations assume self.starting. Else, I and J are flipped
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
......@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
biases = [mask_bias, triangle_bias]
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size)
x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
x = self.mha(q_x=x, kv_x=x, biases=biases)
x = self.mha(
q_x=x,
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
)
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
return x
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
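For reference, the partialmethod aliasing used above is a compact way to keep the old class names as thin wrappers over a single parameterized implementation. A self-contained sketch of the pattern (toy classes, not from the diff):

from functools import partialmethod

class Direction:
    def __init__(self, starting: bool = True):
        self.starting = starting

class Starting(Direction):
    # Same implementation, with the flag baked in at construction time.
    __init__ = partialmethod(Direction.__init__, starting=True)

class Ending(Direction):
    __init__ = partialmethod(Direction.__init__, starting=False)

assert Starting().starting and not Ending().starting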
......@@ -20,7 +20,9 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
......@@ -55,12 +57,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
def _combine_projections(self,
a: torch.Tensor,
b: torch.Tensor,
_inplace_chunk_size: Optional[int] = None
) -> torch.Tensor:
raise NotImplementedError("This method needs to be overridden")
if(self._outgoing):
a = permute_final_dims(a, (2, 0, 1))
b = permute_final_dims(b, (2, 1, 0))
else:
a = permute_final_dims(a, (2, 1, 0))
b = permute_final_dims(b, (2, 0, 1))
if(_inplace_chunk_size is not None):
# To be replaced by torch vmap
for i in range(0, a.shape[-3], _inplace_chunk_size):
a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
a[..., i: i + _inplace_chunk_size, :, :] = (
torch.matmul(
a_chunk,
b_chunk,
)
)
p = a
else:
p = torch.matmul(a, b)
return permute_final_dims(p, (1, 2, 0))
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
inplace_chunk_size:
Size of chunks used in the main computation. Increase to trade
memory for speed.
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
More memory-efficient, inference-only version of the forward function.
Uses in-place operations, fusion of the addition that happens after
this module in the Evoformer, a smidge of recomputation, and
a cache of overwritten values to lower peak memory consumption of this
module from 5x the size of the input tensor z to 2.5x its size. Useful
for inference on extremely long sequences.
It works as follows. We will make reference to variables used in the
default forward implementation below. Naively, the triangle multiplication
update requires the manifestation of 5 tensors the size of z:
1) z, the "square" input tensor, 2) a, the first projection of z,
3) b, the second projection of z, 4) g, a z-sized mask, and 5) a

z-sized tensor for intermediate computations. For large N, this is
prohibitively expensive; for N=4000, for example, z is more than 8GB
alone. To avoid this problem, we compute b, g, and all intermediate
tensors in small chunks, noting that the chunks required to compute a
chunk of the output depend only on the tensor a and corresponding
vertical and horizontal chunks of z. This suggests an algorithm that
loops over pairs of chunks of z: hereafter "columns" and "rows" of
z, even though each "column" and "row" in fact contains
inplace_chunk_size contiguous true columns and rows of z. Writing
output chunks to a new tensor would bring total memory consumption
down to 3x the size of z. However, more memory can be saved by writing
output chunks directly to z in-place. WLOG, we choose to write output
chunks vertically, overwriting the ith "column" of z at the end of
the ith iteration of the main loop. Despite this overwriting, the
ith column is always one column ahead of previously overwritten columns
and can be recovered directly from z. After the first iteration,
however, the ith row of z is always at least partially overwritten. For
this reason, we introduce the z-cache, a tensor one-half the size of
z. The z-cache initially contains the left half (2nd and 3rd quadrants)
of z. For 0 < i < N/2, the missing left part of the ith row of z is
recovered from this cache at the beginning of the ith iteration. Once i
exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
quadrants of z instead. Though the 3rd quadrant of the original z is
entirely overwritten at this point, it can be recovered from the z-cache
itself. Thereafter, the ith row of z can be recovered in its entirety
from the reoriented z-cache. After the final iteration, z has been
completely overwritten and contains the triangular multiplicative
update. If with_add is True, it instead contains the sum of z and the
triangular multiplicative update. In either case, peak memory
consumption is just 2.5x the size of z, disregarding memory used for
chunks and other small variables.
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask, a=True):
if(a):
linear_g = self.linear_a_g
linear_p = self.linear_a_p
else:
linear_g = self.linear_b_g
linear_p = self.linear_b_p
pair = self.layer_norm_in(pair)
p = linear_g(pair)
p.sigmoid_()
p *= linear_p(pair)
p *= mask
p = permute_final_dims(p, (2, 0, 1))
return p
def compute_projection(pair, mask, a=True, chunked=True):
need_transpose = self._outgoing ^ a
if(not chunked):
p = compute_projection_helper(pair, mask, a)
if(need_transpose):
p = p.transpose(-1, -2)
else:
# This computation is chunked so as not to exceed our 2.5x
# budget with a large intermediate tensor
linear_g = self.linear_a_g if a else self.linear_b_g
c = linear_g.bias.shape[-1]
out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
p = pair.new_zeros(out_shape)
for i in range(0, pair.shape[-3], inplace_chunk_size):
pair_chunk = pair[..., i: i + inplace_chunk_size, :, :]
mask_chunk = mask[..., i: i + inplace_chunk_size, :, :]
pair_chunk = compute_projection_helper(
pair[..., i: i + inplace_chunk_size, :, :],
mask[..., i: i + inplace_chunk_size, :, :],
a,
)
if(need_transpose):
pair_chunk = pair_chunk.transpose(-1, -2)
p[..., i: i + inplace_chunk_size] = pair_chunk
else:
p[..., i: i + inplace_chunk_size, :] = pair_chunk
del pair_chunk
return p
# We start by fully manifesting a. In addition to the input, this
# brings total memory consumption to 2x z (disregarding size of chunks)
# [*, N, N, c]
a = compute_projection(z, mask, True, chunked=True)
if(inplace_chunk_size is not None):
n = a.shape[-1]
half_n = n // 2 + n % 2
row_dim = -3
col_dim = -2
b_chunk_dim = row_dim if self._outgoing else col_dim
def empty_slicer(t):
return [slice(None) for _ in t.shape]
def slice_tensor(t, start, end, dim):
# Slices start:end from the dim dimension of t
s = empty_slicer(t)
s[dim] = slice(start, end)
return t[s]
def flip_z_cache_(z_cache, z):
# "Reorient" the z_cache (see below), filling it with quadrants
# 3---recovered from the z_cache---and 4---recovered from z---
# of the input tensor z.
quadrant_3 = slice_tensor(
z_cache, half_n, None, row_dim
)
z_cache = z_cache.transpose(row_dim, col_dim)
# If n is odd, we need to shrink the z_cache by one row
z_cache = z_cache[..., :(n // 2), :, :]
# Move the 3rd quadrant of z into the first half of the z-cache
first_half_slicer = empty_slicer(z_cache)
first_half_slicer[col_dim] = slice(0, half_n)
z_cache[first_half_slicer] = quadrant_3
# Get the fourth quadrant of z
quadrant_4 = slice_tensor(z, half_n, None, row_dim)
quadrant_4 = slice_tensor(
quadrant_4, half_n, None, col_dim
)
# Insert said quadrant into the rotated z-cache
quadrant_3_slicer = empty_slicer(z_cache)
quadrant_3_slicer[col_dim] = slice(half_n, None)
z_cache[quadrant_3_slicer] = quadrant_4
return z_cache
# Initialize the z cache to the left half of z.
z_cache_shape = list(z.shape)
z_cache_shape[col_dim] = half_n
z_cache = z.new_zeros(z_cache_shape)
z_cache_slicer = empty_slicer(z_cache)
z_cache_slicer[col_dim] = slice(0, half_n)
z_cache.copy_(z[z_cache_slicer])
z_cache_rotated = False
# We need to reorient the z-cache at the halfway point, and we
# don't want a single chunk to straddle that point. We contract one
# of the chunks in the middle to address that problem.
i_range = list(range(0, half_n, inplace_chunk_size))
initial_offsets = [
i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])
]
after_half = list(range(half_n, n, inplace_chunk_size))
after_half_offsets = [inplace_chunk_size for _ in after_half]
combined_range_with_offsets = zip(
i_range + after_half, initial_offsets + after_half_offsets
)
for i, offset in combined_range_with_offsets:
if(not z_cache_rotated and i >= half_n):
z_cache = flip_z_cache_(z_cache, z)
z_cache_rotated = True
z_chunk_b = slice_tensor(
z, i, i + offset, b_chunk_dim,
)
mask_chunk = slice_tensor(
mask, i, i + offset, b_chunk_dim,
)
z_chunk_b = z_chunk_b.clone()
if(b_chunk_dim == col_dim):
z_chunk_b = slice_tensor(
z, i, i + offset, col_dim
)
else: # b_chunk_dim == row_dim
# In this case, the b-dimension (b_chunk_dim) is partially
# overwritten at the end of each iteration. We need to
# restore the missing component from the z-cache.
if(not z_cache_rotated):
z_chunk_slicer = empty_slicer(z_chunk_b)
z_chunk_slicer[col_dim] = slice(0, half_n)
z_chunk_b[z_chunk_slicer] = slice_tensor(
z_cache, i, i + offset, row_dim,
)
else:
z_cache_offset = i - half_n
z_chunk_b = slice_tensor(
z_cache,
z_cache_offset, z_cache_offset + offset,
row_dim
)
b_chunk = compute_projection(
z_chunk_b, mask_chunk, a=False, chunked=False
)
del z_chunk_b
x_chunk = torch.matmul(
a,
b_chunk,
)
x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
x_chunk = self.layer_norm_out(x_chunk)
x_chunk = self.linear_z(x_chunk)
# The g dimension (col_dim) is parallel to and ahead of the
# overwrites in z. We can extract the g chunk normally.
z_chunk_g = slice_tensor(
z, i, i + offset, col_dim
)
g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
g_chunk.sigmoid_()
del z_chunk_g
x_chunk *= g_chunk
# Write the columns into z in-place
z_slicer = empty_slicer(z)
z_slicer[col_dim] = slice(i, i + offset)
if(with_add):
z[z_slicer] += x_chunk
else:
z[z_slicer] = x_chunk
else:
b = compute_projection(z, mask, False, False)
x = torch.matmul(a, b)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g.sigmoid_()
x *= g
if(with_add):
z += x
else:
z = x
return z
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
"""
Args:
......@@ -71,57 +371,52 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(inplace_safe):
x = self._inference_forward(
z,
mask,
inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
a = a * mask
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
b = b * mask
x = self._combine_projections(a, b)
a = mask
a = a * self.sigmoid(self.linear_a_g(z))
a = a * self.linear_a_p(z)
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
z = x * g
x = x * g
return z
return x
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
def _combine_projections(self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
def _combine_projections(self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
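A hedged sketch of how the inference path above is meant to be driven: with inplace_safe=True the update is written into z itself, and _add_with_inplace=True fuses the residual addition that would otherwise happen in the Evoformer. Module construction and tensors are assumed elsewhere; c_z and c_hidden stand for channel sizes:

tri_mul = TriangleMultiplicationOutgoing(c_z, c_hidden)
tri_mul.eval()
with torch.no_grad():
    # z: [*, N_res, N_res, C_z]; the call returns a reference to the overwritten z
    z = tri_mul(
        z,
        mask=pair_mask,
        inplace_safe=True,
        _add_with_inplace=True,    # z <- z + update instead of z <- update
        _inplace_chunk_size=256,
    )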
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
......@@ -16,8 +16,9 @@
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
from typing import Any, Sequence, Mapping, Optional
import re
import string
from openfold.np import residue_constants
from Bio.PDB import PDBParser
......@@ -51,16 +52,25 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index: Optional[np.ndarray] = None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark: Optional[str] = None
# Templates used to generate this protein (prediction-only)
parents: Optional[Sequence[str]] = None
# Chain corresponding to each parent
parents_chain_index: Optional[Sequence[int]] = None
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
......@@ -104,7 +114,6 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
......@@ -129,17 +138,32 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints
parents = None
parents_chain_index = None
if("PARENT" in pdb_str):
parents = []
parents_chain_index = []
chain_id = 0
for l in pdb_str.split("\n"):
if("PARENT" in l):
if(not "N/A" in l):
parent_names = l.split()[1:]
parents.extend(parent_names)
parents_chain_index.extend([
chain_id for _ in parent_names
])
chain_id += 1
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
......@@ -149,6 +173,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
parents=parents,
parents_chain_index=parents_chain_index,
)
......@@ -213,6 +239,78 @@ def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
remark = prot.remark
if(remark is not None):
pdb_headers.append(f"REMARK {remark}")
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if(parents_chain_index is not None):
parents = [
p for i, p in zip(parents_chain_index, parents) if i == chain_id
]
if(parents is None or len(parents) == 0):
parents = ["N/A"]
pdb_headers.append(f"PARENT {' '.join(parents)}")
return pdb_headers
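For illustration, a hedged example of what get_pdb_headers produces for a prediction annotated with a remark and per-chain template parents. prot is assumed to be an existing Protein (e.g. from from_prediction); the field values are made up:

import dataclasses

prot = dataclasses.replace(
    prot,
    remark="relaxed prediction",
    parents=["1abc_A", "2xyz_B"],
    parents_chain_index=[0, 1],
)
print(get_pdb_headers(prot, chain_id=0))
# ['REMARK relaxed prediction', 'PARENT 1abc_A']
print(get_pdb_headers(prot, chain_id=1))
# ['REMARK relaxed prediction', 'PARENT 2xyz_B']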
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
""" Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines = []
lines = pdb_str.split('\n')
remark = prot.remark
if(remark is not None):
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None
if(prot.parents is not None and len(prot.parents) > 0):
parents_per_chain = []
if(prot.parents_chain_index is not None):
cur_chain = prot.parents_chain_index[0]
parent_dict = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
for i in range(max_idx + 1):
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(prot.parents)
else:
parents_per_chain = [["N/A"]]
make_parent_line = lambda p: f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
chain_counter = 0
for i, l in enumerate(lines):
if("PARENT" not in l and "REMARK" not in l):
out_pdb_lines.append(l)
if("TER" in l and not "END" in lines[i + 1]):
chain_counter += 1
if(not chain_counter >= len(parents_per_chain)):
chain_parents = parents_per_chain[chain_counter]
else:
chain_parents = ["N/A"]
out_pdb_lines.append(make_parent_line(chain_parents))
return '\n'.join(out_pdb_lines)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
......@@ -232,8 +330,8 @@ def to_pdb(prot: Protein) -> str:
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index.astype(np.int32)
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
......@@ -247,9 +345,17 @@ def to_pdb(prot: Protein) -> str:
)
chain_ids[i] = PDB_CHAIN_IDS[i]
headers = get_pdb_headers(prot)
if (len(headers) > 0):
pdb_lines.extend(headers)
pdb_lines.append("MODEL 1")
n = aatype.shape[0]
atom_index = 1
last_chain_index = chain_index[0]
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
......@@ -281,10 +387,17 @@ def to_pdb(prot: Protein) -> str:
0
] # Protein supports only C, N, O, S, this works.
charge = ""
chain_tag = "A"
if(chain_index is not None):
chain_tag = chain_tags[chain_index[i]]
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
#TODO: check this refactor, chose main branch version
#f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
......@@ -293,16 +406,28 @@ def to_pdb(prot: Protein) -> str:
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1]
)
)
should_terminate = (i == n - 1)
if(chain_index is not None):
if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
should_terminate = True
prev_chain_index = chain_index[i + 1]
if(should_terminate):
# Close the chain.
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} "
f"{res_1to3(aatype[i]):>3} "
f"{chain_tag:>1}{residue_index[i]:>4}"
)
pdb_lines.append(chain_termination_line)
atom_index += 1
if(i != n - 1):
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
......@@ -332,6 +457,9 @@ def from_prediction(
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None
) -> Protein:
"""Assembles a protein from a prediction.
......@@ -341,7 +469,9 @@ def from_prediction(
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
......@@ -349,7 +479,7 @@ def from_prediction(
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
chain_index = _maybe_remove_leading_dim(features["asym_id"]) - 1
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
......@@ -363,6 +493,9 @@ def from_prediction(
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
chain_index=chain_index,
b_factors=b_factors,
chain_index=chain_index,
remark=remark,
parents=parents,
parents_chain_index=parents_chain_index,
)
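A hedged usage sketch of the extended from_prediction signature, forwarding template metadata so that to_pdb and add_pdb_headers can emit the corresponding REMARK and PARENT records. The feature and output dicts are assumed to come from an OpenFold run; the other values are illustrative:

prot = from_prediction(
    features=processed_feature_dict,
    result=model_output,
    b_factors=plddt_b_factors,              # optional per-atom confidence values
    remove_leading_feature_dimension=False,
    remark="produced with OpenFold",
    parents=["4xyz_A"],                      # template names, if any
    parents_chain_index=[0],
)
pdb_str = to_pdb(prot)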
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
......@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
try:
# openmm >= 7.6
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
......@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
pdb_string = '\n'.join(['\n'.join(headers), pdb_string])
return pdb_string
......@@ -511,6 +524,9 @@ def run_pipeline(
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
# We keep the input around to restore metadata deleted by the relaxer
input_prot = prot
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
......@@ -527,6 +543,11 @@ def run_pipeline(
max_attempts=max_attempts,
use_gpu=use_gpu,
)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
......
......@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
try:
# openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
......
......@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations
......@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
try:
# openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
......
......@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
......@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int), (21, 1)
np.arange(14, dtype=int), (21, 1)
)
......
import os
import glob
import importlib as importlib
from . import kernel
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
] + ["kernel"]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules