"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "635f1e94e855d7832363ecdb2ed70affe487608a"
Commit 6e66b218 authored by Gustaf Ahdritz

Vastly lower peak inference memory usage

parent ec5619fc
@@ -10,6 +10,29 @@ def set_inf(c, inf):
            c[k] = inf
def enforce_config_constraints(config):
def string_to_setting(s):
path = s.split('.')
setting = config
for p in path:
setting = setting[p]
return setting
mutually_exclusive_bools = [
(
"model.template.average_templates",
"model.template.offload_templates"
)
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
def model_config(name, train=False, low_prec=False):
    c = copy.deepcopy(config)
    if name == "initial_training":
@@ -22,6 +45,14 @@ def model_config(name, train=False, low_prec=False):
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
+    elif name == "finetuning_ptm":
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+        c.model.heads.tm.enabled = True
+        c.loss.tm.weight = 0.1
    elif name == "model_1":
        # AF2 Suppl. Table 5, Model 1.1.1
        c.data.train.max_extra_msa = 5120
@@ -95,6 +126,8 @@ def model_config(name, train=False, low_prec=False):
    # a global constant
    set_inf(c, 1e4)

+    enforce_config_constraints(c)
+
    return c
@@ -346,6 +379,16 @@ config = mlc.ConfigDict(
                "enabled": templates_enabled,
                "embed_angles": embed_template_torsion_angles,
                "use_unit_vector": False,
+                # Approximate template computation, saving memory.
+                # In our experiments, results are equivalent to or better
+                # than the stock implementation. Should be enabled for all
+                # new training runs.
+                "average_templates": False,
+                # Offload template embeddings to CPU memory. Vastly reduces
+                # memory consumption at the cost of a modest increase in
+                # runtime. Useful for inference on very long sequences.
+                # Mutually exclusive with average_templates.
+                "offload_templates": False,
            },
            "extra_msa": {
                "extra_msa_embedder": {
@@ -498,7 +541,7 @@ config = mlc.ConfigDict(
                    "min_resolution": 0.1,
                    "max_resolution": 3.0,
                    "eps": eps,  # 1e-8,
-                    "weight": 0.0,
+                    "weight": 0.,
                    "enabled": tm_enabled,
                },
                "eps": eps,
......
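A sketch of how these new flags might be toggled for long-sequence inference (illustration, not part of this commit; the import path assumes the standard openfold.config module):

from openfold.config import model_config

c = model_config("model_1")
# Offload template pair embeddings to CPU for very long sequences
c.model.template.offload_templates = True
# average_templates must stay False: the two options are mutually exclusive
assert not c.model.template.average_templates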
@@ -625,13 +625,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
                self.train_chain_data_cache_path,
            ]

+            generator = None
+            if(self.batch_seed is not None):
+                generator = torch.Generator()
+                generator = generator.manual_seed(self.batch_seed + 1)
+
            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
+                generator=generator,
                _roll_at_init=False,
            )

        if(self.val_data_dir is not None):
            self.eval_dataset = dataset_gen(
@@ -660,7 +667,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
            # Filter the dataset, if necessary
            dataset.reroll()
        elif(stage == "eval"):
......
@@ -97,7 +97,8 @@ def unify_template_features(
        chain_indices = np.array(n_templates * [i])
        out_dict["template_chain_index"] = chain_indices

-        out_dicts.append(out_dict)
+        if(n_templates != 0):
+            out_dicts.append(out_dict)

    out_dict = {
        k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
@@ -741,7 +742,7 @@ class DataPipeline:
    ) -> FeatureDict:
        """
            Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
-            hack from Twitter. No templates.
+            hack from Twitter (a.k.a. AlphaFold-Gap).
        """
        with open(fasta_path, 'r') as f:
            fasta_str = f.read()
......
@@ -728,6 +728,7 @@ def make_atom14_positions(protein):
    for index, correspondence in enumerate(correspondences):
        renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix
+
    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )
......
@@ -15,10 +15,10 @@
import torch
import torch.nn as nn

-from typing import Tuple
+from typing import Tuple, Optional

from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import one_hot
+from openfold.utils.tensor_utils import add, one_hot


class InputEmbedder(nn.Module):
@@ -132,7 +132,6 @@ class RecyclingEmbedder(nn.Module):
    Implements Algorithm 32.
    """
-
    def __init__(
        self,
        c_m: int,
@@ -174,6 +173,7 @@ class RecyclingEmbedder(nn.Module):
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
+        _inplace: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -189,6 +189,19 @@ class RecyclingEmbedder(nn.Module):
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
+        # [*, N, C_m]
+        m_update = self.layer_norm_m(m)
+        if(_inplace):
+            m.copy_(m_update)
+            m_update = m
+
+        # [*, N, N, C_z]
+        z_update = self.layer_norm_z(z)
+        if(_inplace):
+            z.copy_(z_update)
+            z_update = z
+
+        # This squared method might become problematic in FP16 mode.
        bins = torch.linspace(
            self.min_bin,
            self.max_bin,
@@ -197,13 +210,6 @@ class RecyclingEmbedder(nn.Module):
            device=x.device,
            requires_grad=False,
        )

-        # [*, N, C_m]
-        m_update = self.layer_norm_m(m)
-
-        # This squared method might become problematic in FP16 mode.
-        # I'm using it because my homegrown method had a stubborn discrepancy I
-        # couldn't find in time.
        squared_bins = bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
@@ -217,7 +223,7 @@ class RecyclingEmbedder(nn.Module):
        # [*, N, N, C_z]
        d = self.linear(d)
-        z_update = d + self.layer_norm_z(z)
+        z_update = add(z_update, d, _inplace)

        return m_update, z_update
@@ -315,7 +321,6 @@ class ExtraMSAEmbedder(nn.Module):
    Implements Algorithm 2, line 15
    """
-
    def __init__(
        self,
        c_in: int,
......
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.tensor_utils import add, chunk_layer


class MSATransition(nn.Module):
@@ -192,32 +192,76 @@ class EvoformerBlockCore(nn.Module):
        msa_trans_mask = msa_mask if _mask_trans else None
        pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size,
-        )
-        z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size,
-        )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(
-                z,
-                mask=pair_mask,
-                chunk_size=chunk_size,
-                use_lma=use_lma
-            )
-        )
-        z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(
-                z,
-                mask=pair_mask,
-                chunk_size=chunk_size,
-                use_lma=use_lma,
-            )
-        )
-        z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size,
-        )
+        # Need to dodge activation checkpoints
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
+        m = add(m,
+            self.msa_transition(
+                m, mask=msa_trans_mask, chunk_size=chunk_size,
+            ),
+            inplace=inplace_safe,
+        )
+        z = add(z,
+            self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
+            ),
+            inplace=inplace_safe,
+        )
+
+        tmu_update = self.tri_mul_out(
+            z,
+            mask=pair_mask,
+            _inplace=inplace_safe,
+            _add_with_inplace=True,
+        )
+        if(not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        tmu_update = self.tri_mul_in(
+            z,
+            mask=pair_mask,
+            _inplace=inplace_safe,
+            _add_with_inplace=True,
+        )
+        if(not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        z = add(z,
+            self.ps_dropout_row_layer(
+                self.tri_att_start(
+                    z,
+                    mask=pair_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma
+                )
+            ),
+            inplace=inplace_safe,
+        )
+        z = add(z,
+            self.ps_dropout_col_layer(
+                self.tri_att_end(
+                    z,
+                    mask=pair_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+        z = add(z,
+            self.pair_transition(
+                z, mask=pair_trans_mask, chunk_size=chunk_size,
+            ),
+            inplace=inplace_safe,
+        )

        return m, z
@@ -377,40 +421,35 @@ class ExtraMSABlock(nn.Module):
        use_lma: bool = False,
        _chunk_logits: Optional[int] = 1024,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        def add(m1, m2):
-            # The first operation in a checkpoint can't be in-place, but it's
-            # nice to have in-place addition during inference. Thus...
-            if(torch.is_grad_enabled()):
-                m1 = m1 + m2
-            else:
-                m1 += m2
-
-            return m1
-
-        m = add(m, self.msa_dropout_layer(
-            self.msa_att_row(
-                m.clone() if torch.is_grad_enabled() else m,
-                z=z.clone() if torch.is_grad_enabled() else z,
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                use_lma=use_lma,
-                use_memory_efficient_kernel=not _chunk_logits and not use_lma,
-                _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
-                _checkpoint_chunks=
-                    self.ckpt if torch.is_grad_enabled() else False,
-            )
-        ))
+        # If function calls could speak...
+        m = add(m,
+            self.msa_dropout_layer(
+                self.msa_att_row(
+                    m.clone() if torch.is_grad_enabled() else m,
+                    z=z.clone() if torch.is_grad_enabled() else z,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                    use_memory_efficient_kernel=not _chunk_logits and not use_lma,
+                    _chunk_logits=
+                        _chunk_logits if torch.is_grad_enabled() else None,
+                    _checkpoint_chunks=
+                        self.ckpt if torch.is_grad_enabled() else False,
+                )
+            ),
+            inplace=not (self.training or torch.is_grad_enabled()),
+        )

        def fn(m, z):
-            m = add(
-                m,
+            m = add(m,
                self.msa_att_col(
                    m,
                    mask=msa_mask,
                    chunk_size=chunk_size,
                    use_lma=use_lma,
-                )
+                ),
+                inplace=not (self.training or torch.is_grad_enabled()),
            )

            m, z = self.core(
                m,
@@ -590,7 +629,6 @@ class ExtraMSAStack(nn.Module):
    """
    Implements Algorithm 18.
    """
-
    def __init__(self,
        c_m: int,
        c_z: int,
......
...@@ -12,18 +12,12 @@ ...@@ -12,18 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import weakref
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.model.embedders import ( from openfold.model.embedders import (
InputEmbedder, InputEmbedder,
RecyclingEmbedder, RecyclingEmbedder,
...@@ -33,16 +27,26 @@ from openfold.model.embedders import ( ...@@ -33,16 +27,26 @@ from openfold.model.embedders import (
) )
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule from openfold.model.structure_module import StructureModule
from openfold.model.template import ( from openfold.model.template import (
TemplatePairStack, TemplatePairStack,
TemplatePointwiseAttention, TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
) )
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
) )
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add,
dict_multimap, dict_multimap,
tensor_tree_map, tensor_tree_map,
) )
...@@ -64,52 +68,71 @@ class AlphaFold(nn.Module): ...@@ -64,52 +68,71 @@ class AlphaFold(nn.Module):
super(AlphaFold, self).__init__() super(AlphaFold, self).__init__()
self.globals = config.globals self.globals = config.globals
config = config.model self.config = config.model
template_config = config.template self.template_config = self.config.template
extra_msa_config = config.extra_msa self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module # Main trunk + structure module
self.input_embedder = InputEmbedder( self.input_embedder = InputEmbedder(
**config["input_embedder"], **self.config["input_embedder"],
) )
self.recycling_embedder = RecyclingEmbedder( self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"], **self.config["recycling_embedder"],
) )
self.template_angle_embedder = TemplateAngleEmbedder( self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"], **self.template_config["template_angle_embedder"],
) )
self.template_pair_embedder = TemplatePairEmbedder( self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"], **self.template_config["template_pair_embedder"],
) )
self.template_pair_stack = TemplatePairStack( self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"], **self.template_config["template_pair_stack"],
) )
self.template_pointwise_att = TemplatePointwiseAttention( self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"], **self.template_config["template_pointwise_attention"],
) )
self.extra_msa_embedder = ExtraMSAEmbedder( self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"], **self.extra_msa_config["extra_msa_embedder"],
) )
self.extra_msa_stack = ExtraMSAStack( self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"], **self.extra_msa_config["extra_msa_stack"],
) )
self.evoformer = EvoformerStack( self.evoformer = EvoformerStack(
**config["evoformer_stack"], **self.config["evoformer_stack"],
) )
self.structure_module = StructureModule( self.structure_module = StructureModule(
**config["structure_module"], **self.config["structure_module"],
) )
self.aux_heads = AuxiliaryHeads( self.aux_heads = AuxiliaryHeads(
config["heads"], self.config["heads"],
) )
self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim): def embed_templates(self, batch, z, pair_mask, templ_dim):
if(self.template_config.offload_templates):
return embed_templates_offload(
self, batch, z, pair_mask, templ_dim,
)
elif(self.template_config.average_templates):
return embed_templates_average(
self, batch, z, pair_mask, templ_dim
)
inplace_safe = not (self.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
template_embeds = [] pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
if(inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.globals.c_t)
)
for i in range(n_templ): for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i) idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map( single_template_feats = tensor_tree_map(
...@@ -117,18 +140,7 @@ class AlphaFold(nn.Module): ...@@ -117,18 +140,7 @@ class AlphaFold(nn.Module):
batch, batch,
) )
single_template_embeds = {} # [*, N, N, C_t]
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat( t = build_template_pair_feat(
single_template_feats, single_template_feats,
use_unit_vector=self.config.template.use_unit_vector, use_unit_vector=self.config.template.use_unit_vector,
...@@ -138,23 +150,27 @@ class AlphaFold(nn.Module): ...@@ -138,23 +150,27 @@ class AlphaFold(nn.Module):
).to(z.dtype) ).to(z.dtype)
t = self.template_pair_embedder(t) t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t}) if(inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
template_embeds.append(single_template_embeds) if(not inplace_safe):
t_pair = torch.cat(pair_embeds, dim=templ_dim)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim), del pair_embeds
template_embeds,
)
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = self.template_pair_stack( t = self.template_pair_stack(
template_embeds["pair"], t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
del t_pair
# [*, N, N, C_z] # [*, N, N, C_z]
t = self.template_pointwise_att( t = self.template_pointwise_att(
...@@ -164,17 +180,28 @@ class AlphaFold(nn.Module): ...@@ -164,17 +180,28 @@ class AlphaFold(nn.Module):
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
) )
t = t * (torch.sum(batch["template_mask"]) > 0)
if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {} ret = {}
if self.config.template.embed_angles: if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"] template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
return ret return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary # Primary output dictionary
outputs = {} outputs = {}
...@@ -190,13 +217,14 @@ class AlphaFold(nn.Module): ...@@ -190,13 +217,14 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2] n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3] n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device device = feats["target_feat"].device
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features # Prep some features
seq_mask = feats["seq_mask"] seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :] pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"] msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations ## Initialize the MSA and pair representations
# m: [*, S_c, N, C_m] # m: [*, S_c, N, C_m]
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
...@@ -206,7 +234,11 @@ class AlphaFold(nn.Module): ...@@ -206,7 +234,11 @@ class AlphaFold(nn.Module):
feats["msa_feat"], feats["msa_feat"],
) )
# Initialize the recycling embeddings, if needs be # Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function.
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
...@@ -236,24 +268,16 @@ class AlphaFold(nn.Module): ...@@ -236,24 +268,16 @@ class AlphaFold(nn.Module):
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, x_prev,
_inplace=not (self.training or torch.is_grad_enabled()),
) )
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m] # [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z] # [*, N, N, C_z]
z += z_prev_emb z += z_prev_emb
# Possibly prevents memory fragmentation # This matters during inference with large N
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
...@@ -269,7 +293,10 @@ class AlphaFold(nn.Module): ...@@ -269,7 +293,10 @@ class AlphaFold(nn.Module):
) )
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if self.config.template.embed_angles: if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
...@@ -289,7 +316,7 @@ class AlphaFold(nn.Module): ...@@ -289,7 +316,7 @@ class AlphaFold(nn.Module):
if self.config.extra_msa.enabled: if self.config.extra_msa.enabled:
# [*, S_e, N, C_e] # [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats)) a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z] # [*, N, N, C_z]
z = self.extra_msa_stack( z = self.extra_msa_stack(
a, a,
...@@ -301,6 +328,8 @@ class AlphaFold(nn.Module): ...@@ -301,6 +328,8 @@ class AlphaFold(nn.Module):
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
del a
# Run MSA + pair embeddings through the trunk of the network # Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m] # m: [*, S, N, C_m]
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
...@@ -416,6 +445,7 @@ class AlphaFold(nn.Module): ...@@ -416,6 +445,7 @@ class AlphaFold(nn.Module):
""" """
# Initialize recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
# Disable activation checkpointing for the first few recycling iters # Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
...@@ -440,12 +470,15 @@ class AlphaFold(nn.Module): ...@@ -440,12 +470,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model # Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, feats,
m_1_prev, prevs,
z_prev,
x_prev,
_recycle=(num_iters > 1) _recycle=(num_iters > 1)
) )
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads # Run auxiliary heads
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
......
@@ -82,7 +82,13 @@ class OuterProductMean(nn.Module):
                no_batch_dims=1,
            )
            out.append(outer)
-        outer = torch.stack(out, dim=0)
+
+        # For some cursed reason making this distinction saves memory
+        if(len(out) == 1):
+            outer = out[0].unsqueeze(0)
+        else:
+            outer = torch.stack(out, dim=0)
+
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer
@@ -90,7 +96,8 @@ class OuterProductMean(nn.Module):
    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        _inplace: bool = False,
    ) -> torch.Tensor:
        """
        Args:
@@ -105,12 +112,17 @@ class OuterProductMean(nn.Module):
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
-        m = self.layer_norm(m)
+        ln = self.layer_norm(m)

        # [*, N_seq, N_res, C]
        mask = mask.unsqueeze(-1)
-        a = self.linear_1(m) * mask
-        b = self.linear_2(m) * mask
+        a = self.linear_1(ln)
+        a = a * mask
+
+        b = self.linear_2(ln)
+        b = b * mask
+
+        del ln

        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)
@@ -122,8 +134,12 @@ class OuterProductMean(nn.Module):
        # [*, N_res, N_res, 1]
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
+        norm = norm + self.eps

        # [*, N_res, N_res, C_z]
-        outer = outer / (self.eps + norm)
+        if(_inplace):
+            outer /= norm
+        else:
+            outer = outer / norm

        return outer
...@@ -34,10 +34,16 @@ from openfold.model.triangular_multiplicative_update import ( ...@@ -34,10 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.utils.checkpointing import checkpoint_blocks from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add,
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
tensor_tree_map,
) )
...@@ -191,7 +197,8 @@ class TemplatePairStackBlock(nn.Module): ...@@ -191,7 +197,8 @@ class TemplatePairStackBlock(nn.Module):
mask: torch.Tensor, mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False, use_lma: bool = False,
_mask_trans: bool = True _mask_trans: bool = True,
_inplace: bool = False,
): ):
single_templates = [ single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4) t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
...@@ -202,44 +209,71 @@ class TemplatePairStackBlock(nn.Module): ...@@ -202,44 +209,71 @@ class TemplatePairStackBlock(nn.Module):
for i in range(len(single_templates)): for i in range(len(single_templates)):
single = single_templates[i] single = single_templates[i]
single_mask = single_templates_masks[i] single_mask = single_templates_masks[i]
single = single + self.dropout_row( single = add(single,
self.tri_att_start( self.dropout_row(
single, self.tri_att_start(
chunk_size=chunk_size, single,
mask=single_mask, chunk_size=chunk_size,
use_lma=use_lma, mask=single_mask,
) use_lma=use_lma,
)
),
_inplace,
) )
single = single + self.dropout_col(
self.tri_att_end( single = add(single,
single, self.dropout_col(
chunk_size=chunk_size, self.tri_att_end(
mask=single_mask, single,
use_lma=use_lma, chunk_size=chunk_size,
) mask=single_mask,
) use_lma=use_lma,
single = single + self.dropout_row( )
self.tri_mul_out( ),
single, _inplace,
mask=single_mask
)
) )
single = single + self.dropout_row(
self.tri_mul_in( tmu_update = self.tri_mul_out(
single, single,
mask=single_mask mask=single_mask,
) _inplace=_inplace,
_add_with_inplace=True,
) )
single = single + self.pair_transition( if(not _inplace):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
single, single,
mask=single_mask if _mask_trans else None, mask=single_mask,
chunk_size=chunk_size, _inplace=_inplace,
_add_with_inplace=True,
)
if(not _inplace):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
single = add(single,
self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
),
_inplace,
) )
single_templates[i] = single if(not _inplace):
single_templates[i] = single
z = torch.cat(single_templates, dim=-4) if(not _inplace):
z = torch.cat(single_templates, dim=-4)
return z return z
...@@ -328,6 +362,7 @@ class TemplatePairStack(nn.Module): ...@@ -328,6 +362,7 @@ class TemplatePairStack(nn.Module):
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
_inplace=not (self.training or torch.is_grad_enabled()),
) )
for b in self.blocks for b in self.blocks
], ],
...@@ -338,3 +373,223 @@ class TemplatePairStack(nn.Module): ...@@ -338,3 +373,223 @@ class TemplatePairStack(nn.Module):
t = self.layer_norm(t) t = self.layer_norm(t)
return t return t
def embed_templates_offload(
model,
batch,
z,
pair_mask,
templ_dim,
template_chunk_size=256,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
template_chunk_size:
Integer value controlling how quickly the offloaded pair embedding
tensor is brought back into GPU memory. In dire straits, can be
lowered to reduce memory consumption of this function even more.
Returns:
A dictionary of template pair and angle embeddings.
A version of the "embed_templates" method of the AlphaFold class that
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
).to(z.dtype)
t = model.template_pair_embedder(t)
# [*, 1, N, N, C_z]
t = model.template_pair_stack(
t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_lma=model.globals.use_lma,
_mask_trans=model.config._mask_trans,
)
pair_embeds_cpu.append(t.cpu())
del t
# Preallocate the output tensor
t = z.new_zeros(z.shape)
for i in range(0, n, template_chunk_size):
pair_chunks = [
p[..., i: i + template_chunk_size, :, :] for p in pair_embeds_cpu
]
pair_chunk = torch.cat(pair_chunks, dim=templ_dim).to(device=z.device)
z_chunk = z[..., i: i + template_chunk_size, :, :]
att_chunk = model.template_pointwise_att(
pair_chunk,
z_chunk,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=model.globals.use_lma,
)
t[..., i: i + template_chunk_size, :, :] = att_chunk
del pair_chunks
if(inplace_safe):
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
if model.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret.update({"template_pair_embedding": t})
return ret
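The offloading pattern above can be illustrated in isolation with a toy sketch (not OpenFold code; the mean reduction merely stands in for the pointwise attention): keep the large per-template tensors in CPU memory and stream slices back to a device-resident output one chunk at a time.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
chunk = 256

big_cpu = [torch.randn(512, 512, 32) for _ in range(4)]   # stays in CPU RAM
out = torch.zeros(512, 512, 32, device=device)            # lives on the device

for i in range(0, out.shape[0], chunk):
    # Only one chunk's worth of per-template data is resident on the device
    stacked = torch.stack([t[i: i + chunk] for t in big_cpu]).to(device)
    out[i: i + chunk] = stacked.mean(dim=0)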
def embed_templates_average(
model,
batch,
z,
pair_mask,
templ_dim,
templ_group_size=2,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
templ_group_size:
Granularity of the approximation. Larger values trade memory for
greater proximity to the original function
Returns:
A dictionary of template pair and angle embeddings.
A memory-efficient approximation of the "embed_templates" method of the
AlphaFold class. Instead of running pointwise attention over pair
embeddings for all of the templates at the same time, it splits templates
into groups of size templ_group_size, computes embeddings for each group
normally, and then averages the group embeddings. In our experiments, this
approximation has a minimal effect on the quality of the resulting
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
out_tensor = z.new_zeros(z.shape)
for i in range(0, n_templ, templ_group_size):
def slice_template_tensor(t):
s = [slice(None) for _ in t.shape]
s[templ_dim] = slice(i, i + templ_group_size)
return t[s]
template_feats = tensor_tree_map(
slice_template_tensor,
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
).to(z.dtype)
# [*, S_t, N, N, C_z]
t = model.template_pair_embedder(t)
t = model.template_pair_stack(
t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_lma=model.globals.use_lma,
_mask_trans=model.config._mask_trans,
)
t = model.template_pointwise_att(
t,
z,
template_mask=template_feats["template_mask"].to(dtype=z.dtype),
use_lma=model.globals.use_lma,
)
denom = math.ceil(n_templ / templ_group_size)
if(inplace_safe):
t /= denom
else:
t = t / denom
if(inplace_safe):
out_tensor += t
else:
out_tensor = out_tensor + t
del t
if(inplace_safe):
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
if model.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret.update({"template_pair_embedding": out_tensor})
return ret
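The grouping arithmetic used above, spelled out on a toy case (illustration only): with 5 templates and templ_group_size=2, the loop forms groups [0:2], [2:4], [4:5], and each group's pointwise-attention output is divided by ceil(5 / 2) = 3 before being summed, i.e. the group embeddings are averaged with equal weight.

import math

n_templ, templ_group_size = 5, 2
groups = [
    (i, min(i + templ_group_size, n_templ))
    for i in range(0, n_templ, templ_group_size)
]
denom = math.ceil(n_templ / templ_group_size)
print(groups, denom)  # [(0, 2), (2, 4), (4, 5)] 3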
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims from openfold.utils.tensor_utils import add, chunk_layer, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module): class TriangleMultiplicativeUpdate(nn.Module):
...@@ -55,12 +55,310 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -55,12 +55,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
def _combine_projections(self, def _combine_projections(self,
a: torch.Tensor, a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
_inplace_chunk_size: Optional[int] = None
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError("This method needs to be overridden") if(self._outgoing):
a = permute_final_dims(a, (2, 0, 1))
b = permute_final_dims(b, (2, 1, 0))
else:
a = permute_final_dims(a, (2, 1, 0))
b = permute_final_dims(b, (2, 0, 1))
if(_inplace_chunk_size is not None):
# To be replaced by torch vmap
for i in range(0, a.shape[-3], _inplace_chunk_size):
a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
a[..., i: i + _inplace_chunk_size, :, :] = (
torch.matmul(
a_chunk,
b_chunk,
)
)
p = a
else:
p = torch.matmul(a, b)
return permute_final_dims(p, (1, 2, 0))
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
inplace_chunk_size:
Size of chunks used in the main computation. Increase to trade
memory for speed.
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
More memory-efficient, inference-only version of the forward function.
Uses in-place operations, fusion of the addition that happens after
this module in the Evoformer, a smidge of recomputation, and
a cache of overwritten values to lower peak memory consumption of this
module from 5x the size of the input tensor z to 2.5x its size. Useful
for inference on extremely long sequences.
It works as follows. We will make reference to variables used in the
        default forward implementation below. Naively, triangle multiplication
        requires the manifestation of 5 tensors the size of z:
        1) z, the "square" input tensor, 2) a, the first projection of z,
        3) b, the second projection of z, 4) g, a z-sized mask, and 5) a
z-sized tensor for intermediate computations. For large N, this is
prohibitively expensive; for N=4000, for example, z is more than 8GB
alone. To avoid this problem, we compute b, g, and all intermediate
tensors in small chunks, noting that the chunks required to compute a
chunk of the output depend only on the tensor a and corresponding
vertical and horizontal chunks of z. This suggests an algorithm that
loops over pairs of chunks of z: hereafter "columns" and "rows" of
z, even though each "column" and "row" in fact contains
inplace_chunk_size contiguous true columns and rows of z. Writing
output chunks to a new tensor would bring total memory consumption
down to 3x the size of z. However, more memory can be saved by writing
output chunks directly to z in-place. WLOG, we choose to write output
chunks vertically, overwriting the ith "column" of z at the end of
the ith iteration of the main loop. Despite this overwriting, the
ith column is always one column ahead of previously overwritten columns
and can be recovered directly from z. After the first iteration,
however, the ith row of z is always at least partially overwritten. For
this reason, we introduce the z-cache, a tensor one-half the size of
z. The z-cache initially contains the left half (2nd and 3rd quadrants)
of z. For 0 < i < N/2, the missing left part of the ith row of z is
recovered from this cache at the beginning of the ith iteration. Once i
exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
quadrants of z instead. Though the 3rd quadrant of the original z is
entirely overwritten at this point, it can be recovered from the z-cache
itself. Thereafter, the ith row of z can be recovered in its entirety
from the reoriented z-cache. After the final iteration, z has been
completely overwritten and contains the triangular multiplicative
update. If with_add is True, it instead contains the sum of z and the
triangular multiplicative update. In either case, peak memory
consumption is just 2.5x the size of z, disregarding memory used for
chunks and other small variables.
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask, a=True):
if(a):
linear_g = self.linear_a_g
linear_p = self.linear_a_p
else:
linear_g = self.linear_b_g
linear_p = self.linear_b_p
pair = self.layer_norm_in(pair)
p = linear_g(pair)
p.sigmoid_()
p *= linear_p(pair)
p *= mask
p = permute_final_dims(p, (2, 0, 1))
return p
def compute_projection(pair, mask, a=True, chunked=True):
need_transpose = self._outgoing ^ a
if(not chunked):
p = compute_projection_helper(pair, mask, a)
if(need_transpose):
p = p.transpose(-1, -2)
else:
# This computation is chunked so as not to exceed our 2.5x
# budget with a large intermediate tensor
linear_g = self.linear_a_g if a else self.linear_b_g
c = linear_g.bias.shape[-1]
out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
p = pair.new_zeros(out_shape)
for i in range(0, pair.shape[-3], inplace_chunk_size):
pair_chunk = pair[..., i: i + inplace_chunk_size, :, :]
mask_chunk = mask[..., i: i + inplace_chunk_size, :, :]
pair_chunk = compute_projection_helper(
pair[..., i: i + inplace_chunk_size, :, :],
mask[..., i: i + inplace_chunk_size, :, :],
a,
)
if(need_transpose):
pair_chunk = pair_chunk.transpose(-1, -2)
p[..., i: i + inplace_chunk_size] = pair_chunk
else:
p[..., i: i + inplace_chunk_size, :] = pair_chunk
del pair_chunk
return p
# We start by fully manifesting a. In addition to the input, this
# brings total memory consumption to 2x z (disregarding size of chunks)
# [*, N, N, c]
a = compute_projection(z, mask, True, chunked=True)
if(inplace_chunk_size is not None):
n = a.shape[-1]
half_n = n // 2 + n % 2
row_dim = -3
col_dim = -2
b_chunk_dim = row_dim if self._outgoing else col_dim
def empty_slicer(t):
return [slice(None) for _ in t.shape]
def slice_tensor(t, start, end, dim):
# Slices start:end from the dim dimension of t
s = empty_slicer(t)
s[dim] = slice(start, end)
return t[s]
def flip_z_cache_(z_cache, z):
# "Reorient" the z_cache (see below), filling it with quadrants
# 3---recovered from the z_cache---and 4---recovered from z---
# of the input tensor z.
quadrant_3 = slice_tensor(
z_cache, half_n, None, row_dim
)
z_cache = z_cache.transpose(row_dim, col_dim)
# If n is odd, we need to shrink the z_cache by one row
z_cache = z_cache[..., :(n // 2), :, :]
                # Move the 3rd quadrant of z into the first half of the
                # rotated z-cache
first_half_slicer = empty_slicer(z_cache)
first_half_slicer[col_dim] = slice(0, half_n)
z_cache[first_half_slicer] = quadrant_3
# Get the fourth quadrant of z
quadrant_4 = slice_tensor(z, half_n, None, row_dim)
quadrant_4 = slice_tensor(
quadrant_4, half_n, None, col_dim
)
# Insert said quadrant into the rotated z-cache
quadrant_3_slicer = empty_slicer(z_cache)
quadrant_3_slicer[col_dim] = slice(half_n, None)
z_cache[quadrant_3_slicer] = quadrant_4
return z_cache
# Initialize the z cache to the left half of z.
z_cache_shape = list(z.shape)
z_cache_shape[col_dim] = half_n
z_cache = z.new_zeros(z_cache_shape)
z_cache_slicer = empty_slicer(z_cache)
z_cache_slicer[col_dim] = slice(0, half_n)
z_cache.copy_(z[z_cache_slicer])
z_cache_rotated = False
# We need to reorient the z-cache at the halfway point, and we
# don't want a single chunk to straddle that point. We contract one
# of the chunks in the middle to address that problem.
i_range = list(range(0, half_n, inplace_chunk_size))
initial_offsets = [
i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])
]
after_half = list(range(half_n, n, inplace_chunk_size))
after_half_offsets = [inplace_chunk_size for _ in after_half]
combined_range_with_offsets = zip(
i_range + after_half, initial_offsets + after_half_offsets
)
for i, offset in combined_range_with_offsets:
if(not z_cache_rotated and i >= half_n):
z_cache = flip_z_cache_(z_cache, z)
z_cache_rotated = True
z_chunk_b = slice_tensor(
z, i, i + offset, b_chunk_dim,
)
mask_chunk = slice_tensor(
mask, i, i + offset, b_chunk_dim,
)
z_chunk_b = z_chunk_b.clone()
if(b_chunk_dim == col_dim):
z_chunk_b = slice_tensor(
z, i, i + offset, col_dim
)
else: # b_chunk_dim == row_dim
# In this case, the b-dimension (b_chunk_dim) is partially
# overwritten at the end of each iteration. We need to
# restore the missing component from the z-cache.
if(not z_cache_rotated):
z_chunk_slicer = empty_slicer(z_chunk_b)
z_chunk_slicer[col_dim] = slice(0, half_n)
z_chunk_b[z_chunk_slicer] = slice_tensor(
z_cache, i, i + offset, row_dim,
)
else:
z_cache_offset = i - half_n
z_chunk_b = slice_tensor(
z_cache,
z_cache_offset, z_cache_offset + offset,
row_dim
)
b_chunk = compute_projection(
z_chunk_b, mask_chunk, a=False, chunked=False
)
del z_chunk_b
x_chunk = torch.matmul(
a,
b_chunk,
)
x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
x_chunk = self.layer_norm_out(x_chunk)
x_chunk = self.linear_z(x_chunk)
# The g dimension (col_dim) is parallel to and ahead of the
# overwrites in z. We can extract the g chunk normally.
z_chunk_g = slice_tensor(
z, i, i + offset, col_dim
)
g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
g_chunk.sigmoid_()
del z_chunk_g
x_chunk *= g_chunk
# Write the columns into z in-place
z_slicer = empty_slicer(z)
z_slicer[col_dim] = slice(i, i + offset)
if(with_add):
z[z_slicer] += x_chunk
else:
z[z_slicer] = x_chunk
else:
b = compute_projection(z, mask, False, False)
x = torch.matmul(a, b)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g.sigmoid_()
x *= g
if(with_add):
z += x
else:
z = x
return z
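The chunking idea described in the docstring above can be seen in isolation in a small sketch (assumptions: a toy gated projection with bare weight matrices stands in for the module's linear_*_g / linear_*_p layers; this is not the module's code). Results are written into a preallocated buffer one row-chunk at a time, so the full-size gated intermediate never has to coexist with a second copy of itself.

import torch

def chunked_gated_projection(pair, weight_g, weight_p, chunk_size=64):
    # pair: [N, N, C_in] -> returns [C_out, N, N], materializing only one
    # [chunk, N, C_out] intermediate at a time
    n = pair.shape[-3]
    c_out = weight_p.shape[-1]
    out = pair.new_zeros((c_out, n, pair.shape[-2]))
    for i in range(0, n, chunk_size):
        chunk = pair[i: i + chunk_size]                  # [chunk, N, C_in]
        g = torch.sigmoid(chunk @ weight_g)              # gate
        p = (chunk @ weight_p) * g                       # gated projection
        out[:, i: i + chunk_size] = p.permute(2, 0, 1)   # [C_out, chunk, N]
    return out

pair = torch.randn(128, 128, 32)
w_g, w_p = torch.randn(32, 16), torch.randn(32, 16)
full = torch.sigmoid(pair @ w_g) * (pair @ w_p)
assert torch.allclose(
    chunked_gated_projection(pair, w_g, w_p), full.permute(2, 0, 1), atol=1e-5
)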
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: Optional[torch.Tensor] = None mask: Optional[torch.Tensor] = None,
_inplace: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -71,57 +369,46 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -71,57 +369,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns: Returns:
[*, N_res, N_res, C_z] output tensor [*, N_res, N_res, C_z] output tensor
""" """
if(_inplace):
x = self._inference_forward(
z,
mask,
inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
-        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
-        a = a * mask
-        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
-        b = b * mask
+        a = mask
+        a = a * self.sigmoid(self.linear_a_g(z))
+        a = a * self.linear_a_p(z)
+        b = mask
+        b = b * self.sigmoid(self.linear_b_g(z))
+        b = b * self.linear_b_p(z)
        x = self._combine_projections(a, b)
+        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
-        z = x * g
+        x = x * g

-        return z
+        return x


class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
-    def _combine_projections(self,
-        a: torch.Tensor,  # [*, N_i, N_k, C]
-        b: torch.Tensor,  # [*, N_j, N_k, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 0, 1)),
-            permute_final_dims(b, (2, 1, 0)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
+    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
-    def _combine_projections(self,
-        a: torch.Tensor,  # [*, N_k, N_i, C]
-        b: torch.Tensor,  # [*, N_k, N_j, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 1, 0)),
-            permute_final_dims(b, (2, 0, 1)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
+    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
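A minimal illustration of the partialmethod idiom the two subclasses now use (not OpenFold code): the subclasses differ only in the _outgoing flag baked into their constructors.

from functools import partialmethod

class Base:
    def __init__(self, c, _outgoing=True):
        self.c = c
        self._outgoing = _outgoing

class Outgoing(Base):
    __init__ = partialmethod(Base.__init__, _outgoing=True)

class Incoming(Base):
    __init__ = partialmethod(Base.__init__, _outgoing=False)

assert Outgoing(8)._outgoing and not Incoming(8)._outgoing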
@@ -140,12 +140,20 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
            residue_index.append(res.id[1])
            b_factors.append(res_b_factors)

+    parents = None
+    if("PARENT" in pdb_str):
+        for l in pdb_str.split("\n"):
+            if("PARENT" in l and not "N/A" in l):
+                parents = l.split()[1:]
+                break
+
    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        b_factors=np.array(b_factors),
+        parents=parents,
    )
......
@@ -516,6 +516,9 @@ def run_pipeline(
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

+    # We keep the input around to restore metadata deleted by the relaxer
+    input_prot = prot
+
    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    violations = np.inf
@@ -532,6 +535,11 @@ def run_pipeline(
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )
+
+        headers = protein.get_pdb_headers(prot)
+        if(len(headers) > 0):
+            ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])
+
        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            pdb_string = clean_protein(prot, checks=True)
......
@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
-        self.params = state_dict["params"]
+        for k in state_dict["params"].keys():
+            self.params[k] = state_dict["params"][k].clone()
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
......
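Why the clone matters (illustration only, not OpenFold code): without it, the EMA buffers would alias the checkpoint's tensors, and later in-place EMA updates would silently mutate the loaded state dict as well.

import torch

state = {"params": {"w": torch.zeros(3)}}

aliased = state["params"]["w"]           # shares storage with the checkpoint
cloned = state["params"]["w"].clone()    # independent copy

aliased.add_(1.0)
assert torch.equal(state["params"]["w"], torch.ones(3))   # mutated
assert torch.equal(cloned, torch.zeros(3))                # unaffected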
@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels):
-    log_p = torch.log(torch.sigmoid(logits))
-    log_not_p = torch.log(torch.sigmoid(-logits))
-    loss = -labels * log_p - (1 - labels) * log_not_p
+    logits_dtype = logits.dtype
+    logits = logits.double()
+    labels = labels.double()
+    log_p = torch.nn.functional.logsigmoid(logits)
+    # log_p = torch.log(torch.sigmoid(logits))
+    log_not_p = torch.nn.functional.logsigmoid(-1 * logits)
+    # log_not_p = torch.log(torch.sigmoid(-logits))
+    loss = (-1. * labels) * log_p - (1. - labels) * log_not_p
+    loss = loss.to(dtype=logits_dtype)
    return loss
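Why logsigmoid and double precision: for large-magnitude negative logits, torch.sigmoid underflows to 0 and log(sigmoid(x)) becomes -inf, while logsigmoid evaluates the same quantity stably. A small check (illustration only):

import torch

x = torch.tensor([-150.0, 0.0, 150.0])
naive = torch.log(torch.sigmoid(x))          # ~[-inf, -0.6931, 0.0]
stable = torch.nn.functional.logsigmoid(x)   # ~[-150.0, -0.6931, 0.0]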
@@ -1472,13 +1478,13 @@ def experimentally_resolved_loss(
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
    loss = torch.sum(loss, dim=-1)

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    loss = torch.mean(loss)

    return loss
......
@@ -19,6 +19,17 @@ import torch.nn as nn

from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional


+def add(m1, m2, inplace):
+    # The first operation in a checkpoint can't be in-place, but it's
+    # nice to have in-place addition during inference. Thus...
+    if(not inplace):
+        m1 = m1 + m2
+    else:
+        m1 += m2
+
+    return m1
+
+
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
......
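A sketch of how the helper is meant to be used (illustration only): callers pass inplace=True only outside of training and gradient tracking, so the first operation inside an activation checkpoint is never in-place.

import torch
from openfold.utils.tensor_utils import add

z = torch.zeros(4, 4)
update = torch.ones(4, 4)

with torch.no_grad():
    out = add(z, update, inplace=not torch.is_grad_enabled())
# Here out is z itself, mutated in place; with grad enabled, z would be
# left untouched and out would be a freshly allocated tensor.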
@@ -110,7 +110,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
-    if(feature_processor.config.common.use_templates):
+    if(feature_processor.config.common.use_templates and "template_domain_names" in feature_dict):
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]
......
@@ -227,7 +227,7 @@ if __name__ == "__main__":
    )
    add_data_args(parser)
    parser.add_argument(
-        "--raise_errors", type=bool, default=False,
+        "--raise_errors", action="store_true", default=False,
        help="Whether to crash on parsing errors"
    )
    parser.add_argument(
......
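Why the argparse change above: with type=bool, any non-empty string parses as True, so "--raise_errors False" would still enable the flag. A small demonstration (illustration only):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--old_style", type=bool, default=False)
p.add_argument("--raise_errors", action="store_true", default=False)

args = p.parse_args(["--old_style", "False"])
assert args.old_style is True        # surprising, but that's bool("False")
assert args.raise_errors is False    # store_true only flips when passed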
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
    def test_shape(self):
        c_z = consts.c_z
        c = 11

-        outgoing = True
        tm = TriangleMultiplicationOutgoing(
            c_z,
            c,
-            outgoing,
        )

        n_res = consts.c_z
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
-            _inference_mode=True,
+            _inplace_chunk_size=4,
        ).cpu()

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self): def test_tri_mul_out_compare(self):
...@@ -106,6 +105,40 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -106,6 +105,40 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_tri_mul_in_compare(self): def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True) self._tri_mul_compare(incoming=True)
def _tri_mul_inference_mode(self, incoming=False):
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
_inference_mode=False,
).cpu()
# This has to come second because inference mode is in-place
out_inference_mode = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
_inference_mode=True, _inplace_chunk_size=2,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_stock - out_inference_mode)) < consts.eps)
def test_tri_mul_out_inference(self):
self._tri_mul_inference_mode()
def test_tri_mul_in_inference(self):
self._tri_mul_inference_mode(incoming=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()