Unverified Commit bb3f51e5 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
......@@ -18,7 +18,7 @@ import glob
import logging
import os
import subprocess
from typing import Any, Mapping, Optional, Sequence
from typing import Any, List, Mapping, Optional, Sequence
from openfold.data.tools import utils
......@@ -99,9 +99,9 @@ class HHBlits:
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
......@@ -172,4 +172,4 @@ class HHBlits:
n_iter=self.n_iter,
e_value=self.e_value,
)
return raw_output
return [raw_output]
......@@ -18,8 +18,9 @@ import glob
import logging
import os
import subprocess
from typing import Sequence
from typing import Sequence, Optional
from openfold.data import parsers
from openfold.data.tools import utils
......@@ -62,11 +63,20 @@ class HHSearch:
f"Could not find HHsearch database {database_path}"
)
def query(self, a3m: str) -> str:
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
output_dir = query_tmp_dir if output_dir is None else output_dir
hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
......@@ -104,3 +114,12 @@ class HHSearch:
with open(hhr_path) as f:
hhr = f.read()
return hhr
@staticmethod
def get_template_hits(
output_string: str,
input_sequence: str
) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool"""
del input_sequence # Used by hmmsearch but not needed for hhsearch
return parsers.parse_hhr(output_string)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import subprocess
from absl import logging
from openfold.data.tools import utils
class Hmmbuild(object):
"""Python wrapper of the hmmbuild binary."""
def __init__(self,
*,
binary_path: str,
singlemx: bool = False):
"""Initializes the Python hmmbuild wrapper.
Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.
Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self.binary_path = binary_path
self.singlemx = singlemx
def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
return self._build_profile(sto, model_construction=model_construction)
def build_profile_from_a3m(self, a3m: str) -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
a3m: A string with the aligned sequences in the A3M format.
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
lines = []
for line in a3m.splitlines():
if not line.startswith('>'):
line = re.sub('[a-z]+', '', line) # Remove inserted residues.
lines.append(line + '\n')
msa = ''.join(lines)
return self._build_profile(msa, model_construction='fast')
def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
"""Builds a HMM for the aligned sequences given as an MSA string.
Args:
msa: A string with the aligned sequences, in A3M or STO format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if model_construction not in {'hand', 'fast'}:
raise ValueError(f'Invalid model_construction {model_construction} - only'
'hand and fast supported.')
with utils.tmpdir_manager() as query_tmp_dir:
input_query = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
with open(input_query, 'w') as f:
f.write(msa)
cmd = [self.binary_path]
# If adding flags, we have to do so before the output and input:
if model_construction == 'hand':
cmd.append(f'--{model_construction}')
if self.singlemx:
cmd.append('--singlemx')
cmd.extend([
'--amino',
output_hmm_path,
input_query,
])
logging.info('Launching subprocess %s', cmd)
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
with utils.timing('hmmbuild query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
if retcode:
raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(output_hmm_path, encoding='utf-8') as f:
hmm = f.read()
return hmm
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
from typing import Optional, Sequence
from absl import logging
from openfold.data import parsers
from openfold.data.tools import hmmbuild
from openfold.data.tools import utils
class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""
def __init__(self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None
):
"""Initializes the Python hmmsearch wrapper.
Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an hmm from an input a3m.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.
Raises:
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
'--F2', '0.1',
'--F3', '0.1',
'--incE', '100',
'-E', '100',
'--domE', '100',
'--incdomE', '100']
self.flags = flags
if not os.path.exists(self.database_path):
logging.error('Could not find hmmsearch database %s', database_path)
raise ValueError(f'Could not find hmmsearch database {database_path}')
@property
def output_format(self) -> str:
return 'sto'
@property
def input_format(self) -> str:
return 'sto'
def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(
msa_sto,
model_construction='hand'
)
return self.query_with_hmm(hmm, output_dir)
def query_with_hmm(self,
hmm: str,
output_dir: Optional[str] = None
) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
output_dir = query_tmp_dir if output_dir is None else output_dir
out_path = os.path.join(output_dir, 'hmm_output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)
cmd = [
self.binary_path,
'--noali', # Don't include the alignment in stdout.
'--cpu', '8'
]
# If adding flags, we have to do so before the output and input:
if self.flags:
cmd.extend(self.flags)
cmd.extend([
'-A', out_path,
hmm_input_path,
self.database_path,
])
logging.info('Launching sub-process %s', cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'hmmsearch ({os.path.basename(self.database_path)}) query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(out_path) as f:
out_msa = f.read()
return out_msa
@staticmethod
def get_template_hits(
output_string: str,
input_sequence: str
) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
template_hits = parsers.parse_hmmsearch_sto(
output_string,
input_sequence,
)
return template_hits
......@@ -23,6 +23,7 @@ import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from openfold.data import parsers
from openfold.data.tools import utils
......@@ -93,10 +94,13 @@ class Jackhmmer:
self.streaming_callback = streaming_callback
def _query_chunk(
self, input_fasta_path: str, database_path: str
self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
......@@ -167,8 +171,11 @@ class Jackhmmer:
with open(tblout_path) as f:
tbl = f.read()
if(max_sequences is None):
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
raw_output = dict(
sto=sto,
......@@ -180,10 +187,25 @@ class Jackhmmer:
return raw_output
def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]:
return self.query_multiple([input_fasta_path], max_sequences)
def query_multiple(self,
input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]
single_chunk_results = []
for input_fasta_path in input_fasta_paths:
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences,
)
single_chunk_results.append(single_chunk_result)
return single_chunk_results
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
......@@ -198,7 +220,7 @@ class Jackhmmer:
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_output = []
chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
......@@ -216,13 +238,21 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i))
for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
chunked_outputs[fasta_idx].append(
self._query_chunk(
input_fasta_path,
db_local_chunk(i),
max_sequences
)
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if(i < self.num_streamed_chunks):
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
return chunked_outputs
......@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
......
import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
path = os.path.join(alignment_dir, stockholm_file)
file_name,_ = os.path.splitext(stockholm_file)
with open(path, "r") as infile:
msa = parsers.parse_stockholm(infile.read())
infile.close()
return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str):
path = os.path.join(alignment_dir, a3m_file)
file_name,_ = os.path.splitext(a3m_file)
with open(path, "r") as infile:
msa = parsers.parse_a3m(infile.read())
infile.close()
return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str):
# Number of workers based on the tasks
msa_results={}
a3m_tasks = [(alignment_dir, f) for f in a3m_files]
sto_tasks = [(alignment_dir, f) for f in stockholm_files]
with ProcessPoolExecutor(max_workers = len(a3m_tasks) + len(sto_tasks)) as executor:
a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}
for future in concurrent.futures.as_completed(a3m_futures | sto_futures):
try:
result = future.result()
msa_results.update(result)
except Exception as exc:
print(f'Task generated an exception: {exc}')
return msa_results
def main():
parser = argparse.ArgumentParser(description='Process msa files in parallel')
parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
args = parser.parse_args()
alignment_dir = args.alignment_dir
stockholm_files = [i for i in os.listdir(alignment_dir)
if all([i.endswith('.sto'), "hmm_output" not in i, "uniprot_hits" not in i])]
a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
pickle.dump(msa_data, outfile)
print(outfile.name)
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Optional
from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
pseudo_beta_fn,
dgram_from_positions,
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import add, one_hot
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from openfold.utils import geometry
from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
......@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
batch: Dict containing
"target_feat":
Features of shape [*, N_res, tf_dim]
"residue_index":
Features of shape [*, N_res]
"msa_feat":
Features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
......@@ -139,6 +154,161 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if(self.use_chain_relative):
self.no_bins = (
2 * max_relative_idx + 2 +
1 +
2 * max_relative_chain + 2
)
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if(self.use_chain_relative):
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
)
rel_pos = one_hot(
final_offset,
boundaries,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) *
torch.ones_like(clipped_rel_chain)
)
boundaries = torch.arange(
start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
)
rel_chain = one_hot(
final_rel_chain,
boundaries,
)
rel_feats.append(rel_chain)
else:
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
)
rel_pos = one_hot(
clipped_offset, boundaries,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(
self.linear_relpos.weight.dtype
)
return self.linear_relpos(rel_feat)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class PreembeddingEmbedder(nn.Module):
"""
Embeds the sequence pre-embedding passed to the model and the target_feat features.
......@@ -335,7 +505,7 @@ class RecyclingEmbedder(nn.Module):
return m_update, z_update
class TemplateAngleEmbedder(nn.Module):
class TemplateSingleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
......@@ -355,7 +525,7 @@ class TemplateAngleEmbedder(nn.Module):
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
super(TemplateSingleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
......@@ -459,3 +629,356 @@ class ExtraMSAEmbedder(nn.Module):
x = self.linear(x)
return x
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_single_embedder = TemplateSingleEmbedder(
**config["template_single_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_lma=False,
inplace_safe=False
):
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_out)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if (not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out, init='relu')
self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
self.query_embedding_layer_norm = LayerNorm(c_in)
self.query_embedding_linear = Linear(c_in, c_out, init='relu')
self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
self.x_linear = Linear(1, c_out, init='relu')
self.y_linear = Linear(1, c_out, init='relu')
self.z_linear = Linear(1, c_out, init='relu')
self.backbone_mask_linear = Linear(1, c_out, init='relu')
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [(coord * backbone_mask_2d).to(dtype=query_embedding.dtype) for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None].to(dtype=query_embedding.dtype))
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_out: int,
):
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_out)
self.template_projector = Linear(c_out, c_out)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
dtype = batch["template_all_atom_positions"].dtype
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
).to(dtype=dtype)
template_mask = template_chi_mask[..., 0].to(dtype=dtype)
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_lma=False,
inplace_safe=False
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = pseudo_beta_fn(
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],
single_template_feats["template_all_atom_mask"])
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
# Vec3Arrays are required to be float32
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos.to(dtype=torch.float32))
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
......@@ -18,6 +18,7 @@ import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial
from abc import ABC, abstractmethod
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
......@@ -36,6 +37,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
......@@ -117,35 +120,31 @@ class MSATransition(nn.Module):
return m
class EvoformerBlockCore(nn.Module):
class PairStack(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
fuse_projection_weights: bool,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
eps: float
):
super(EvoformerBlockCore, self).__init__()
super(PairStack, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
if fuse_projection_weights:
self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.outer_product_mean = OuterProductMean(
c_m,
self.tri_mul_in = FusedTriangleMultiplicationIncoming(
c_z,
c_hidden_opm,
c_hidden_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
......@@ -176,64 +175,30 @@ class EvoformerBlockCore(nn.Module):
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
def forward(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
z: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
_attn_chunk_size: Optional[int] = None
) -> torch.Tensor:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
if(_attn_chunk_size is None):
if (_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
m = add(
m,
self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
......@@ -246,7 +211,7 @@ class EvoformerBlockCore(nn.Module):
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
......@@ -269,9 +234,8 @@ class EvoformerBlockCore(nn.Module):
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
if (inplace_safe):
z = z.contiguous()
z = add(z,
self.ps_dropout_row_layer(
......@@ -289,9 +253,8 @@ class EvoformerBlockCore(nn.Module):
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
if (inplace_safe):
z = z.contiguous()
z = add(z,
self.pair_transition(
......@@ -300,19 +263,11 @@ class EvoformerBlockCore(nn.Module):
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
device = z.device
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
return z
class EvoformerBlock(nn.Module):
class MSABlock(nn.Module, ABC):
@abstractmethod
def __init__(self,
c_m: int,
c_z: int,
......@@ -325,11 +280,14 @@ class EvoformerBlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
no_column_attention: bool,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
super(MSABlock, self).__init__()
self.opm_first = opm_first
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
......@@ -339,30 +297,127 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
# Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention
if not self.no_column_attention:
self.msa_att_col = MSAColumnAttention(
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_hidden_msa_att,
no_heads_msa,
c_z,
c_hidden_opm,
)
self.pair_stack = PairStack(
c_z=c_z,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
def _compute_opm(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
_offload_inference: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
self.core = EvoformerBlockCore(
c_m=c_m,
m, z = input_tensors
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
return m, z
@abstractmethod
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pass
class EvoformerBlock(MSABlock):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
no_column_attention: bool,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__(c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps)
# Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention
if not self.no_column_attention:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
eps=eps,
)
def forward(self,
......@@ -380,6 +435,9 @@ class EvoformerBlock(nn.Module):
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
msa_trans_mask = msa_mask if _mask_trans else None
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
......@@ -391,6 +449,15 @@ class EvoformerBlock(nn.Module):
m, z = input_tensors
if self.opm_first:
del m, z
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
......@@ -406,6 +473,14 @@ class EvoformerBlock(nn.Module):
inplace=inplace_safe,
)
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
# Specifically, column attention is not used in seqemb mode.
if not self.no_column_attention:
m = add(m,
......@@ -420,28 +495,64 @@ class EvoformerBlock(nn.Module):
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
m = add(
m,
self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if not self.opm_first:
if (not inplace_safe):
input_tensors = [m, z]
del m, z
m, z = self.core(
input_tensors,
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
device = input_tensors[0].device
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
if (not inplace_safe):
input_tensors = [m, z]
del m, z
z = self.pair_stack(
z=input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
_attn_chunk_size=_attn_chunk_size
)
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].to(device)
m, _ = input_tensors
else:
m = input_tensors[0]
return m, z
class ExtraMSABlock(nn.Module):
class ExtraMSABlock(MSABlock):
"""
Almost identical to the standard EvoformerBlock, except in that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
......@@ -460,42 +571,34 @@ class ExtraMSABlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
super(ExtraMSABlock, self).__init__(c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps)
self.ckpt = ckpt
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
......@@ -525,6 +628,15 @@ class ExtraMSABlock(nn.Module):
m, z = input_tensors
if self.opm_first:
del m, z
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
......@@ -542,15 +654,25 @@ class ExtraMSABlock(nn.Module):
inplace=inplace_safe,
)
if(not inplace_safe):
if (not inplace_safe):
input_tensors = [m, z]
del m, z
def fn(input_tensors):
m = add(input_tensors[0],
m, z = input_tensors
if (_offload_inference and inplace_safe):
# m: GPU, z: CPU
del m, z
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
m = add(m,
self.msa_att_col(
input_tensors[0],
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
......@@ -558,27 +680,63 @@ class ExtraMSABlock(nn.Module):
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
m = add(
m,
self.msa_transition(
m, mask=msa_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if not self.opm_first:
if (not inplace_safe):
input_tensors = [m, z]
del m
del m, z
m, z = self.core(
input_tensors,
m, z = self._compute_opm(input_tensors=input_tensors,
msa_mask=msa_mask,
chunk_size=chunk_size,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference)
if (_offload_inference and inplace_safe):
# m: CPU, z: GPU
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
device = input_tensors[0].device
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
if (not inplace_safe):
input_tensors = [m, z]
del m, z
z = self.pair_stack(
input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
_attn_chunk_size=_attn_chunk_size
)
m = input_tensors[0]
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m
assert (sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].to(device)
m, _ = input_tensors
return m, z
if(torch.is_grad_enabled() and self.ckpt):
if (torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, input_tensors)
else:
......@@ -609,8 +767,10 @@ class EvoformerStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
no_column_attention: bool,
opm_first: bool,
fuse_projection_weights: bool,
blocks_per_ckpt: int,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
......@@ -646,11 +806,18 @@ class EvoformerStack(nn.Module):
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
no_column_attention:
When True, doesn't use column attention. Required for running
sequence embedding mode
opm_first:
When True, Outer Product Mean is performed at the beginning of
the Evoformer block instead of after the MSA Stack.
Used in Multimer pipeline.
fuse_projection_weights:
When True, uses FusedTriangleMultiplicativeUpdate variant in
the Pair Stack. Used in Multimer pipeline.
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
......@@ -678,6 +845,8 @@ class EvoformerStack(nn.Module):
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
no_column_attention=no_column_attention,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps,
)
......@@ -873,6 +1042,8 @@ class ExtraMSAStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
eps: float,
ckpt: bool,
......@@ -898,6 +1069,8 @@ class ExtraMSAStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
eps=eps,
ckpt=False,
......
......@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
aux_out["ptm_score"] = compute_tm(
tm_logits, **self.config.tm
)
asym_id = outputs.get("asym_id")
if asym_id is not None:
aux_out["iptm_score"] = compute_tm(
tm_logits, asym_id=asym_id, interface=True, **self.config.tm
)
aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+ self.config.tm["ptm_weight"] * aux_out["ptm_score"])
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
......
......@@ -18,11 +18,20 @@ import weakref
import torch
import torch.nn as nn
from openfold.data import data_transforms_multimer
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
dgram_from_positions,
atom14_to_atom37,
)
from openfold.utils.tensor_utils import masked_mean
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
TemplateEmbedder,
TemplateEmbedderMultimer,
ExtraMSAEmbedder,
PreembeddingEmbedder,
)
......@@ -75,9 +84,13 @@ class AlphaFold(nn.Module):
self.seqemb_mode = config.globals.seqemb_mode_enabled
# Main trunk + structure module
if self.globals.is_multimer:
self.input_embedder = InputEmbedderMultimer(
**self.config["input_embedder"]
)
elif self.seqemb_mode:
# If using seqemb mode, embed the sequence embeddings passed
# to the model ("preembeddings") instead of embedding the sequence
if self.seqemb_mode:
self.input_embedder = PreembeddingEmbedder(
**self.config["preembedding_embedder"],
)
......@@ -85,25 +98,22 @@ class AlphaFold(nn.Module):
self.input_embedder = InputEmbedder(
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**self.config["recycling_embedder"],
)
if(self.template_config.enabled):
self.template_angle_embedder = TemplateAngleEmbedder(
**self.template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**self.template_config["template_pair_embedder"],
if self.template_config.enabled:
if self.globals.is_multimer:
self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
self.template_pair_stack = TemplatePairStack(
**self.template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**self.template_config["template_pointwise_attention"],
else:
self.template_embedder = TemplateEmbedder(
self.template_config,
)
if(self.extra_msa_config.enabled):
if self.extra_msa_config.enabled:
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
......@@ -114,113 +124,87 @@ class AlphaFold(nn.Module):
self.evoformer = EvoformerStack(
**self.config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**self.config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
self.config["heads"],
)
def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
if(self.template_config.offload_templates):
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if self.globals.is_multimer:
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if self.template_config.offload_templates:
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif(self.template_config.average_templates):
elif self.template_config.average_templates:
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if(inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.globals.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
template_embeds = self.template_embedder(
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.template.use_unit_vector,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if(inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if(not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
z,
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=self.globals.use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
_mask_trans=self.config._mask_trans
)
if(inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
return template_embeds
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
"""
Early stopping criteria based on criteria used in
AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
Args:
prev_pos: Previous atom positions in atom37/14 representation
next_pos: Current atom positions in atom37/14 representation
mask: 1-D sequence mask
eps: Epsilon used in square root calculation
Returns:
Whether to stop recycling early based on the desired tolerance.
"""
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
def distances(points):
"""Compute all pairwise distances for a set of points."""
d = points[..., None, :] - points[..., None, :, :]
return torch.sqrt(torch.sum(d ** 2, dim=-1))
ret["template_angle_embedding"] = a
if self.config.recycle_early_stop_tolerance < 0:
return False
return ret
ca_idx = residue_constants.atom_order['CA']
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
mask = mask[..., None] * mask[..., None, :]
sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
diff = torch.sqrt(sq_diff + eps).item()
return diff <= self.config.recycle_early_stop_tolerance
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
......@@ -229,7 +213,7 @@ class AlphaFold(nn.Module):
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
if feats[k].dtype == torch.float32:
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
......@@ -248,18 +232,22 @@ class AlphaFold(nn.Module):
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
## Initialize the SingleSeq and pair representations
if self.globals.is_multimer:
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
elif self.seqemb_mode:
# Initialize the SingleSeq and pair representations
# m: [*, 1, N, C_m]
# z: [*, N, N, C_z]
if self.seqemb_mode:
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["seq_embedding"]
)
else:
## Initialize the MSA and pair representations
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
......@@ -293,12 +281,12 @@ class AlphaFold(nn.Module):
requires_grad=False,
)
x_prev = pseudo_beta_fn(
pseudo_beta_x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
if self.globals.offload_inference and inplace_safe:
m = m.cpu()
z = z.cpu()
......@@ -307,11 +295,13 @@ class AlphaFold(nn.Module):
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
pseudo_beta_x_prev,
inplace_safe=inplace_safe,
)
if(self.globals.offload_inference and inplace_safe):
del pseudo_beta_x_prev
if self.globals.offload_inference and inplace_safe:
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
......@@ -324,15 +314,17 @@ class AlphaFold(nn.Module):
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
......@@ -345,26 +337,40 @@ class AlphaFold(nn.Module):
inplace_safe,
)
if "template_angle_embedding" in template_embeds:
if (
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
[m, template_embeds["template_single_embedding"]],
dim=-3
)
# [*, S, N]
if not self.globals.is_multimer:
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
else:
msa_mask = torch.cat(
[feats["msa_mask"], template_embeds["template_mask"]],
dim=-2,
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
if self.globals.is_multimer:
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
else:
extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
extra_msa_feat = extra_msa_fn(feats).to(dtype=z.dtype)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
if self.globals.offload_inference:
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
......@@ -399,7 +405,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
if self.globals.offload_inference:
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
......@@ -455,10 +461,34 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z_prev = outputs["pair"]
early_stop = False
if self.globals.is_multimer:
early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)
del x_prev
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
return outputs, m_1_prev, z_prev, x_prev, early_stop
def _disable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
for b in self.extra_msa_stack.blocks:
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch):
"""
......@@ -519,13 +549,15 @@ class AlphaFold(nn.Module):
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
early_stop = False
num_recycles = 0
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
is_final_iter = cycle_no == (num_iters - 1) or early_stop
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
# Sidestep AMP bug (PyTorch issue #65766)
......@@ -533,16 +565,25 @@ class AlphaFold(nn.Module):
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
feats,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
num_recycles += 1
if not is_final_iter:
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
else:
break
outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
if "asym_id" in batch:
outputs["asym_id"] = feats["asym_id"]
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
......
......@@ -131,6 +131,7 @@ class Linear(nn.Linear):
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
precision=None
):
"""
Args:
......@@ -182,6 +183,28 @@ class Linear(nn.Linear):
else:
raise ValueError("Invalid init string.")
self.precision = precision
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.comm.comm.is_initialized()
)
if self.precision is not None:
with torch.cuda.amp.autocast(enabled=False):
bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
return nn.functional.linear(input.to(dtype=self.precision),
self.weight.to(dtype=self.precision),
bias).to(dtype=d)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
bias = self.bias.to(dtype=d) if self.bias is not None else None
return nn.functional.linear(input, self.weight.to(dtype=d), bias)
return nn.functional.linear(input, self.weight, self.bias)
class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
......
......@@ -20,7 +20,7 @@ from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple, Sequence
from typing import Optional, Tuple, Sequence, Union
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -29,6 +29,9 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.geometry.quat_rigid import QuatRigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.vector import Vec3Array, square_euclidean_distance
from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
......@@ -158,6 +161,51 @@ class AngleResnet(nn.Module):
return unnormalized_s, s
class PointProjection(nn.Module):
def __init__(self,
c_hidden: int,
num_points: int,
no_heads: int,
is_multimer: bool,
return_local_points: bool = False,
):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.num_points = num_points
self.is_multimer = is_multimer
# Multimer requires this to be run with fp32 precision during training
precision = torch.float32 if self.is_multimer else None
self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)
def forward(self,
activations: torch.Tensor,
rigids: Union[Rigid, Rigid3Array],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
if self.is_multimer:
points_local = points_local.view(
points_local.shape[:-1] + (self.no_heads, -1)
)
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
points_local = torch.stack(points_local, dim=-1).view(out_shape)
points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points):
return points_global, points_local
return points_global
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
......@@ -172,6 +220,7 @@ class InvariantPointAttention(nn.Module):
no_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
is_multimer: bool = False,
):
"""
Args:
......@@ -198,22 +247,46 @@ class InvariantPointAttention(nn.Module):
self.no_v_points = no_v_points
self.inf = inf
self.eps = eps
self.is_multimer = is_multimer
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc)
self.linear_kv = Linear(self.c_s, 2 * hc)
self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))
hpq = self.no_heads * self.no_qk_points * 3
self.linear_q_points = Linear(self.c_s, hpq)
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
self.is_multimer
)
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
self.linear_kv_points = Linear(self.c_s, hpkv)
if(is_multimer):
self.linear_k = Linear(self.c_s, hc, bias=False)
self.linear_v = Linear(self.c_s, hc, bias=False)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
self.is_multimer
)
hpv = self.no_heads * self.no_v_points * 3
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
self.is_multimer
)
else:
self.linear_kv = Linear(self.c_s, 2 * hc)
self.linear_kv_points = PointProjection(
self.c_s,
self.no_qk_points + self.no_v_points,
self.no_heads,
self.is_multimer
)
self.linear_b = Linear(self.c_z, self.no_heads)
......@@ -231,8 +304,8 @@ class InvariantPointAttention(nn.Module):
def forward(
self,
s: torch.Tensor,
z: Optional[torch.Tensor],
r: Rigid,
z: torch.Tensor,
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
......@@ -251,7 +324,7 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
if(_offload_inference and inplace_safe):
if (_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
......@@ -261,41 +334,40 @@ class InvariantPointAttention(nn.Module):
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, P_qk]
q_pts = self.linear_q_points(s, r)
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# The following two blocks are equivalent
# They're separated only to preserve compatibility with old AF weights
if(self.is_multimer):
# [*, N_res, H * C_hidden]
k = self.linear_k(s)
v = self.linear_v(s)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# [*, N_res, H, C_hidden]
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_qk, 3]
k_pts = self.linear_k_points(s, r)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H, P_v, 3]
v_pts = self.linear_v_points(s, r)
else:
# [*, N_res, H * 2 * C_hidden]
kv = self.linear_kv(s)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
kv_pts = self.linear_kv_points(s, r)
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
......@@ -308,12 +380,12 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
b = self.linear_b(z[0])
if(_offload_inference):
assert(sys.getrefcount(z[0]) == 2)
if (_offload_inference):
assert (sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
if(is_fp16_enabled()):
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
......@@ -330,26 +402,29 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
if(inplace_safe):
if (inplace_safe):
pt_att *= pt_att
else:
pt_att = pt_att ** 2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
if(inplace_safe):
if (inplace_safe):
pt_att *= head_weights
else:
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
......@@ -357,7 +432,7 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
if(inplace_safe):
if (inplace_safe):
a += pt_att
del pt_att
a += square_mask.unsqueeze(-3)
......@@ -384,7 +459,7 @@ class InvariantPointAttention(nn.Module):
o = flatten_final_dims(o, 2)
# [*, H, 3, N_res, P_v]
if(inplace_safe):
if (inplace_safe):
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
o_pt = [
torch.matmul(a, v.to(a.dtype))
......@@ -411,8 +486,9 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
if(_offload_inference):
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
# [*, N_res, H, C_z]
......@@ -424,7 +500,233 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
(o, *o_pt, o_pt_norm, o_pair), dim=-1
).to(dtype=z[0].dtype)
)
return s
#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above
# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine
# whether or not the increase in test error matters in practice.
class InvariantPointAttentionMultimer(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(
self,
c_s: int,
c_z: int,
c_hidden: int,
no_heads: int,
no_qk_points: int,
no_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
is_multimer: bool = True,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
no_qk_points:
Number of query/key points to generate
no_v_points:
Number of value points to generate
"""
super(InvariantPointAttentionMultimer, self).__init__()
self.c_s = c_s
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.no_qk_points = no_qk_points
self.no_v_points = no_v_points
self.inf = inf
self.eps = eps
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc, bias=False)
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
is_multimer=True
)
self.linear_k = Linear(self.c_s, hc, bias=False)
self.linear_v = Linear(self.c_s, hc, bias=False)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
is_multimer=True
)
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
is_multimer=True
)
self.linear_b = Linear(self.c_z, self.no_heads)
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.no_heads * (
self.c_z + self.c_hidden + self.no_v_points * 4
)
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
self.softmax = nn.Softmax(dim=-2)
def forward(
self,
s: torch.Tensor,
z: Optional[torch.Tensor],
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
if(_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
a = 0.
point_variance = (max(self.no_qk_points, 1) * 9.0 / 2)
point_weights = math.sqrt(1.0 / point_variance)
softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))
head_weights = softplus(self.head_weights)
point_weights = point_weights * head_weights
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H, P_qk]
q_pts = Vec3Array.from_array(self.linear_q_points(s, r))
# [*, N_res, H, P_qk, 3]
k_pts = Vec3Array.from_array(self.linear_k_points(s, r))
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
pt_att = pt_att.to(dtype=s.dtype)
a = a + pt_att
scalar_variance = max(self.c_hidden, 1) * 1.
scalar_weights = math.sqrt(1.0 / scalar_variance)
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
k = self.linear_k(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
q = q * scalar_weights
a = a + torch.einsum('...qhc,...khc->...qkh', q, k)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z[0])
if (_offload_inference):
assert (sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu()
a = a + b
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
a = a + square_mask.unsqueeze(-1)
a = a * math.sqrt(1. / 3) # Normalize by number of logit terms (3)
a = self.softmax(a)
# [*, N_res, H * C_hidden]
v = self.linear_v(s)
# [*, N_res, H, C_hidden]
v = v.view(v.shape[:-1] + (self.no_heads, -1))
o = torch.einsum('...qkh, ...khc->...qhc', a, v)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, N_res, H, P_v, 3]
v_pts = Vec3Array.from_array(self.linear_v_points(s, r))
# [*, N_res, H, P_v]
o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# o_pt = Vec3Array(
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
# )
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H, P_v]
o_pt = r[..., None].apply_inverse_to_point(o_pt)
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
o_pt_flat = [x.to(dtype=a.dtype) for x in o_pt_flat]
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(epsilon=1e-8)
if (_offload_inference):
z[0] = z[0].to(o_pt.x.device)
o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *o_pt_flat, o_pt_norm, o_pair), dim=-1
).to(dtype=z[0].dtype)
)
......@@ -530,6 +832,7 @@ class StructureModule(nn.Module):
trans_scale_factor,
epsilon,
inf,
is_multimer=False,
**kwargs,
):
"""
......@@ -583,6 +886,7 @@ class StructureModule(nn.Module):
self.trans_scale_factor = trans_scale_factor
self.epsilon = epsilon
self.inf = inf
self.is_multimer = is_multimer
# Buffers to be lazily initialized later
# self.default_frames
......@@ -595,7 +899,8 @@ class StructureModule(nn.Module):
self.linear_in = Linear(self.c_s, self.c_s)
self.ipa = InvariantPointAttention(
ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer
self.ipa = ipa(
self.c_s,
self.c_z,
self.c_ipa,
......@@ -604,6 +909,7 @@ class StructureModule(nn.Module):
self.no_v_points,
inf=self.inf,
eps=self.epsilon,
is_multimer=self.is_multimer,
)
self.ipa_dropout = nn.Dropout(self.dropout_rate)
......@@ -615,6 +921,9 @@ class StructureModule(nn.Module):
self.dropout_rate,
)
if self.is_multimer:
self.bb_update = QuatRigid(self.c_s, full_quat=False)
else:
self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
......@@ -625,7 +934,7 @@ class StructureModule(nn.Module):
self.epsilon,
)
def forward(
def _forward_monomer(
self,
evoformer_output_dict,
aatype,
......@@ -661,8 +970,8 @@ class StructureModule(nn.Module):
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if(_offload_inference):
assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
if (_offload_inference):
assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
......@@ -744,7 +1053,102 @@ class StructureModule(nn.Module):
del z, z_reference_list
if(_offload_inference):
if (_offload_inference):
evoformer_output_dict["pair"] = (
evoformer_output_dict["pair"].to(s.device)
)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
def _forward_multimer(
self,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
s = evoformer_output_dict["single"]
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
# [*, N, C_s]
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if (_offload_inference):
assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
# [*, N, C_s]
s_initial = s
s = self.linear_in(s)
# [*, N]
rigids = Rigid3Array.identity(
s.shape[:-1],
s.device,
)
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(
s,
z,
rigids,
mask,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list
)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
rigids = rigids @ self.bb_update(s)
# [*, N, 7, 2]
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
rigids.scale_translation(self.trans_scale_factor),
angles,
aatype,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
)
preds = {
"frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
}
preds = {k: v.to(dtype=s.dtype) for k, v in preds.items()}
outputs.append(preds)
rigids = rigids.stop_rot_gradient()
del z, z_reference_list
if (_offload_inference):
evoformer_output_dict["pair"] = (
evoformer_output_dict["pair"].to(s.device)
)
......@@ -754,6 +1158,34 @@ class StructureModule(nn.Module):
return outputs
def forward(
self,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
if(self.is_multimer):
outputs = self._forward_multimer(evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference)
else:
outputs = self._forward_monomer(evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference)
return outputs
def _init_residue_constants(self, float_dtype, device):
if not hasattr(self, "default_frames"):
self.register_buffer(
......@@ -809,7 +1241,7 @@ class StructureModule(nn.Module):
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
self._init_residue_constants(r.dtype, r.device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
......
......@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing,
FusedTriangleMultiplicationIncoming
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
......@@ -54,6 +56,7 @@ class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
......@@ -100,7 +103,6 @@ class TemplatePointwiseAttention(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
......@@ -153,6 +155,8 @@ class TemplatePairStackBlock(nn.Module):
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
tri_mul_first: bool,
fuse_projection_weights: bool,
inf: float,
**kwargs,
):
......@@ -165,6 +169,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
self.tri_mul_first = tri_mul_first
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
......@@ -182,6 +187,16 @@ class TemplatePairStackBlock(nn.Module):
inf=inf,
)
if fuse_projection_weights:
self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = FusedTriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
......@@ -196,30 +211,13 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n,
)
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
def tri_att_start_end(self,
single: torch.Tensor,
_attn_chunk_size: Optional[int],
single_mask: torch.Tensor,
use_deepspeed_evo_attention: bool,
use_lma: bool,
inplace_safe: bool):
single = add(single,
self.dropout_row(
self.tri_att_start(
......@@ -248,13 +246,19 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe,
)
return single
def tri_mul_out_in(self,
single: torch.Tensor,
single_mask: torch.Tensor,
inplace_safe: bool):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -267,13 +271,59 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
return single
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
):
if _attn_chunk_size is None:
_attn_chunk_size = chunk_size
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
if self.tri_mul_first:
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
single_mask=single_mask,
inplace_safe=inplace_safe),
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe)
else:
single = self.tri_mul_out_in(
single=self.tri_att_start_end(single=single,
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe),
single_mask=single_mask,
inplace_safe=inplace_safe)
single = add(single,
self.pair_transition(
single,
......@@ -283,10 +333,10 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe,
)
if(not inplace_safe):
if not inplace_safe:
single_templates[i] = single
if(not inplace_safe):
if not inplace_safe:
z = torch.cat(single_templates, dim=-4)
return z
......@@ -296,6 +346,7 @@ class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
......@@ -305,6 +356,8 @@ class TemplatePairStack(nn.Module):
no_heads,
pair_transition_n,
dropout_rate,
tri_mul_first,
fuse_projection_weights,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
......@@ -341,6 +394,8 @@ class TemplatePairStack(nn.Module):
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
)
self.blocks.append(block)
......@@ -349,7 +404,7 @@ class TemplatePairStack(nn.Module):
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
if tune_chunk_size:
self.chunk_size_tuner = ChunkSizeTuner()
def forward(
......@@ -371,7 +426,7 @@ class TemplatePairStack(nn.Module):
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
if mask.shape[-3] == 1:
expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
......@@ -389,8 +444,8 @@ class TemplatePairStack(nn.Module):
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
if chunk_size is not None and self.chunk_size_tuner is not None:
assert (not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
......@@ -478,7 +533,7 @@ def embed_templates_offload(
_mask_trans=model.config._mask_trans,
)
assert(sys.getrefcount(t) == 2)
assert (sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu())
......@@ -504,7 +559,7 @@ def embed_templates_offload(
del pair_chunks
if(inplace_safe):
if inplace_safe:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
......@@ -516,9 +571,9 @@ def embed_templates_offload(
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
a = model.template_single_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": t})
......@@ -605,19 +660,19 @@ def embed_templates_average(
)
denom = math.ceil(n_templ / templ_group_size)
if(inplace_safe):
if inplace_safe:
t /= denom
else:
t = t / denom
if(inplace_safe):
if inplace_safe:
out_tensor += t
else:
out_tensor = out_tensor + t
del t
if(inplace_safe):
if inplace_safe:
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
......@@ -629,9 +684,9 @@ def embed_templates_average(
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
a = model.template_single_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": out_tensor})
......
......@@ -15,6 +15,7 @@
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
......@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
@abstractmethod
def __init__(self, c_z, c_hidden, _outgoing):
"""
Args:
c_z:
......@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__()
super(BaseTriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
......@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
return permute_final_dims(p, (1, 2, 0))
@abstractmethod
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
pass
class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......@@ -397,7 +435,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
# reduced-precision modes
a_std = a.std()
b_std = b.std()
if(a_std != 0. and b_std != 0.):
if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std()
b = b / b.std()
......@@ -428,3 +466,152 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
Implements Algorithm 12.
"""
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(FusedTriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_ab_p = Linear(self.c_z, self.c_hidden * 2)
self.linear_ab_g = Linear(self.c_z, self.c_hidden * 2, init="gating")
def _inference_forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
_inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask):
p = self.linear_ab_g(pair)
p.sigmoid_()
p *= self.linear_ab_p(pair)
p *= mask
return p
def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask)
left = p[..., :self.c_hidden]
right = p[..., self.c_hidden:]
return left, right
z_norm_in = self.layer_norm_in(z)
a, b = compute_projection(z_norm_in, mask)
x = self._combine_projections(a, b, _inplace_chunk_size=_inplace_chunk_size)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z_norm_in)
g.sigmoid_()
x *= g
if (with_add):
z += x
else:
z = x
return z
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if (inplace_safe):
x = self._inference_forward(
z,
mask,
_inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
ab = mask
ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z)
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a_std = a.std()
b_std = b.std()
if (is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std()
b = b / b.std()
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
x = x * g
return x
class FusedTriangleMultiplicationOutgoing(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=True)
class FusedTriangleMultiplicationIncoming(FusedTriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
__init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=False)
......@@ -36,6 +36,11 @@ FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
......@@ -73,6 +78,13 @@ class Protein:
# Chain corresponding to each parent
parents_chain_index: Optional[Sequence[int]] = None
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
"chains because these cannot be written to PDB format"
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
......@@ -108,6 +120,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
......@@ -132,6 +145,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
......@@ -224,6 +238,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
......@@ -316,21 +338,46 @@ def to_pdb(prot: Protein) -> str:
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
chain_index = prot.chain_index.astype(np.int32)
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
headers = get_pdb_headers(prot)
if(len(headers) > 0):
if (len(headers) > 0):
pdb_lines.extend(headers)
pdb_lines.append("MODEL 1")
n = aatype.shape[0]
atom_index = 1
last_chain_index = chain_index[0]
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites.
for i in range(n):
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1]
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
......@@ -355,6 +402,8 @@ def to_pdb(prot: Protein) -> str:
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
#TODO: check this refactor, chose main branch version
#f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
......@@ -386,9 +435,12 @@ def to_pdb(prot: Protein) -> str:
# each new chain.
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
# Pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
def to_modelcif(prot: Protein) -> str:
......@@ -539,7 +591,7 @@ def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
chain_index: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None
......@@ -550,20 +602,32 @@ def from_prediction(
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"]) - 1
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
)
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
b_factors=b_factors,
chain_index=chain_index,
remark=remark,
......
......@@ -563,60 +563,3 @@ def run_pipeline(
)
iteration += 1
return ret
def get_initial_energies(
pdb_strs: Sequence[str],
stiffness: float = 0.0,
restraint_set: str = "non_hydrogen",
exclude_residues: Optional[Sequence[int]] = None,
):
"""Returns initial potential energies for a sequence of PDBs.
Assumes the input PDBs are ready for minimization, and all have the same
topology.
Allows time to be saved by not pdbfixing / rebuilding the system.
Args:
pdb_strs: List of PDB strings.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: Which atom types to restrain.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A list of initial energies in the same order as pdb_strs.
"""
exclude_residues = exclude_residues or []
openmm_pdbs = [
openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
]
force_field = openmm_app.ForceField("amber99sb.xml")
system = force_field.createSystem(
openmm_pdbs[0].topology, constraints=openmm_app.HBonds
)
stiffness = stiffness * ENERGY / (LENGTH ** 2)
if stiffness > 0 * ENERGY / (LENGTH ** 2):
_add_restraints(
system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
)
simulation = openmm_app.Simulation(
openmm_pdbs[0].topology,
system,
openmm.LangevinIntegrator(0, 0.01, 0.0),
openmm.Platform.getPlatformByName("CPU"),
)
energies = []
for pdb in openmm_pdbs:
try:
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
except Exception as e: # pylint: disable=broad-except
logging.error(
"Error getting initial energy, returning large value %s", e
)
energies.append(unit.Quantity(1e20, ENERGY))
return energies
......@@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types(
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"]
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
......
......@@ -17,14 +17,13 @@
import collections
import functools
import os
from typing import Mapping, List, Tuple
from importlib import resources
import numpy as np
import tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
......@@ -450,9 +449,9 @@ def load_stereo_chemical_props() -> Tuple[
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
residue_bonds: Dict that maps resname -> list of Bond tuples
residue_virtual_bonds: Dict that maps resname -> list of Bond tuples
residue_bond_angles: Dict that maps resname -> list of BondAngle tuples
"""
# TODO: this file should be downloaded in a setup script
stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
......@@ -1310,3 +1309,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in restypes:
residue_name = restype_1to3[residue_name]
residue_chi_angles = chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return np.array(chi_atom_indices)
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
restype_1to3[res] for res in restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
def _make_restype_atom37_mask():
"""Mask of which atoms are present for which residue type in atom37."""
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
return restype_atom37_mask
def _make_restype_atom14_mask():
"""Mask of which atoms are present for which residue type in atom14."""
restype_atom14_mask = []
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_mask
def _make_restype_atom37_to_atom14():
"""Map from atom37 to atom14 per residue type."""
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in atom_types
])
restype_atom37_to_atom14.append([0] * 37)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
return restype_atom37_to_atom14
def _make_restype_atom14_to_atom37():
"""Map from atom14 to atom37 per residue type."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_to_atom37.append([
(atom_order[name] if name else 0)
for name in atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
return restype_atom14_to_atom37
def _make_restype_atom14_is_ambiguous():
"""Mask which atoms are ambiguous in atom14."""
# create an ambiguous atoms mask. shape: (21, 14)
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = restype_order[
restype_3to1[resname]]
atom_idx1 = restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
"""Create Map from rigidgroups to atom37 indices."""
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for chi_idx in range(4):
if chi_angles_mask[restype][chi_idx]:
atom_names = chi_angles_atoms[resname][chi_idx]
base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]
# Translate atom names into atom37 indices.
lookuptable = atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
base_atom_names)
return restype_rigidgroup_base_atom37_idx
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()
# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from functools import partial
from typing import Dict, Text, Tuple
import torch
from openfold.np import residue_constants as rc
from openfold.utils import geometry, tensor_utils
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
import numpy as np
def squared_difference(x, y):
return np.square(x - y)
def get_rc_tensor(rc_np, aatype):
return torch.tensor(rc_np, device=aatype.device)[aatype]
def atom14_to_atom37(
atom14_data: torch.Tensor, # (*, N, 14, ...)
aatype: torch.Tensor # (*, N)
) -> Tuple: # (*, N, 37, ...)
"""Convert atom14 to atom37 representation."""
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
no_batch_dims = len(aatype.shape) - 1
atom37_data = tensor_utils.batched_gather(
atom14_data,
idx_atom37_to_atom14,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1
)
atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
if len(atom14_data.shape) == no_batch_dims + 2:
atom37_data *= atom37_mask
elif len(atom14_data.shape) == no_batch_dims + 3:
atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype)
else:
raise ValueError("Incorrectly shaped data")
return atom37_data, atom37_mask
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
"""Convert Atom37 positions to Atom14 positions."""
residx_atom14_to_atom37 = get_rc_tensor(
rc.RESTYPE_ATOM14_TO_ATOM37, aatype
)
no_batch_dims = len(aatype.shape)
atom14_mask = tensor_utils.batched_gather(
all_atom_mask,
residx_atom14_to_atom37,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
).to(all_atom_pos.dtype)
# create a mask for known groundtruth positions
atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
# gather the groundtruth positions
atom14_positions = tensor_utils.batched_gather(
all_atom_pos,
residx_atom14_to_atom37,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
),
atom14_positions = atom14_mask * atom14_positions
return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: torch.Tensor, mask):
"""Get alternative atom14 positions."""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype)
alternative_positions = torch.sum(
positions[..., None, :] * renaming_transform[..., None],
dim=-2
)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = torch.sum(mask[..., None] * renaming_transform, dim=-2)
return alternative_positions, alternative_mask
def atom37_to_frames(
aatype: torch.Tensor, # (...)
all_atom_positions: torch.Tensor, # (..., 37)
all_atom_mask: torch.Tensor, # (..., 37)
) -> Dict[Text, torch.Tensor]:
"""Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
no_batch_dims = len(aatype.shape) - 1
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = get_rc_tensor(
rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype
)
# Gather the base atom positions for each rigid group.
base_atom_pos = tensor_utils.batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim = no_batch_dims + 1,
batch_dims = no_batch_dims + 1,
)
# Compute the Rigids.
point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
origin = base_atom_pos[..., :, :, 1]
point_on_xy_plane = base_atom_pos[..., :, :, 2]
gt_rotation = geometry.Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin
)
gt_frames = geometry.Rigid3Array(gt_rotation, origin)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = tensor_utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.to(dtype=all_atom_positions.dtype),
residx_rigidgroup_base_atom37_idx,
batch_dims=no_batch_dims + 1,
)
gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=all_atom_positions.dtype), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation(
geometry.Rot3Array.from_array(
torch.tensor(rots, device=aatype.device)
)
)
# The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=all_atom_positions.dtype)
restype_rigidgroup_rots = np.tile(
np.eye(3, dtype=all_atom_positions.dtype), [21, 8, 1, 1]
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = torch.tensor(
restype_rigidgroup_is_ambiguous,
device=aatype.device,
)[aatype]
ambiguity_rot = torch.tensor(
restype_rigidgroup_rots,
device=aatype.device,
)[aatype]
ambiguity_rot = geometry.Rot3Array.from_array(
torch.Tensor(ambiguity_rot, device=aatype.device)
)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,))
# reshape back to original residue layout
gt_frames = fix_shape(gt_frames)
gt_exists = fix_shape(gt_exists)
group_exists = fix_shape(group_exists)
residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
alt_gt_frames = fix_shape(alt_gt_frames)
return {
'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8)
}
def torsion_angles_to_frames(
aatype: torch.Tensor, # (N)
backb_to_global: geometry.Rigid3Array, # (N)
torsion_angles_sin_cos: torch.Tensor # (N, 7, 2)
) -> geometry.Rigid3Array: # (N, 8)
"""Compute rigid group frames from torsion angles."""
# Gather the default frames for all rigid groups.
# geometry.Rigid3Array with shape (N, 8)
m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
default_frames = geometry.Rigid3Array.from_array4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues = aatype.shape[-1]
sin_angles = torch.cat(
[
torch.zeros_like(aatype).unsqueeze(dim=-1),
sin_angles,
],
dim=-1)
cos_angles = torch.cat(
[
torch.ones_like(aatype).unsqueeze(dim=-1),
cos_angles
],
dim=-1
)
zeros = torch.zeros_like(sin_angles)
ones = torch.ones_like(sin_angles)
# all_rots are geometry.Rot3Array with shape (..., N, 8)
all_rots = geometry.Rot3Array(
ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles
)
# Apply rotations to the frames.
all_frames = default_frames.compose_rotation(all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi1_frame_to_backb = all_frames[..., 4]
chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]
all_frames_to_backb = Rigid3Array.cat(
[
all_frames[..., 0:5],
chi2_frame_to_backb[..., None],
chi3_frame_to_backb[..., None],
chi4_frame_to_backb[..., None]
],
dim=-1
)
# Create the global frames.
# shape (N, 8)
all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: torch.Tensor, # (*, N)
all_frames_to_global: geometry.Rigid3Array # (N, 8)
) -> geometry.Vec3Array: # (*, N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group."""
# Pick the appropriate transform for every atom.
residx_to_group_idx = get_rc_tensor(
rc.restype_atom14_to_rigid_group,
aatype
)
group_mask = torch.nn.functional.one_hot(
residx_to_group_idx,
num_classes=8
) # shape (*, N, 14, 8)
# geometry.Rigid3Array with shape (N, 14)
map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask
map_atoms_to_global = map_atoms_to_global.map_tensor_fn(
partial(torch.sum, dim=-1)
)
# Gather the literature atom positions for each residue.
# geometry.Vec3Array with shape (N, 14)
lit_positions = geometry.Vec3Array.from_array(
get_rc_tensor(
rc.restype_atom14_rigid_group_positions,
aatype
)
)
# Transform each atom from its local frame to the global frame.
# geometry.Vec3Array with shape (N, 14)
pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
# Mask out non-existing atoms.
mask = get_rc_tensor(rc.restype_atom14_mask, aatype)
pred_positions = pred_positions * mask
return pred_positions
def extreme_ca_ca_distance_violations(
positions: geometry.Vec3Array, # (N, 37(14))
mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
max_angstrom_tolerance=1.5,
eps: float = 1e-6
) -> torch.Tensor:
"""Counts residues whose Ca is a large distance from its neighbor."""
this_ca_pos = positions[..., :-1, 1] # (N - 1,)
this_ca_mask = mask[..., :-1, 1] # (N - 1)
next_ca_pos = positions[..., 1:, 1] # (N - 1,)
next_ca_mask = mask[..., 1:, 1] # (N - 1)
has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
).astype(positions.x.dtype)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1)
def get_chi_atom_indices(device: torch.device):
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[rc.atom_order[atom] for atom in chi_angle]
)
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return torch.tensor(chi_atom_indices, device=device)
def compute_chi_angles(
positions: geometry.Vec3Array,
mask: torch.Tensor,
aatype: torch.Tensor
):
"""Computes the chi angles given all atom positions and the amino acid type.
Args:
positions: A Vec3Array of shape
[num_res, rc.atom_type_num], with positions of
atoms needed to calculate chi angles. Supports up to 1 batch dimension.
mask: An optional tensor of shape
[num_res, rc.atom_type_num] that masks which atom
positions are set for each residue. If given, then the chi mask will be
set to 1 for a chi angle only if the amino acid has that chi angle and all
the chi atoms needed to calculate that chi angle are set. If not given
(set to None), the chi mask will be set to 1 for a chi angle if the amino
acid has that chi angle and whether the actual atoms needed to calculate
it were set will be ignored.
aatype: A tensor of shape [num_res] with amino acid type integer
code (0 to 21). Supports up to 1 batch dimension.
Returns:
A tuple of tensors (chi_angles, mask), where both have shape
[num_res, 4]. The mask masks out unused chi angles for amino acid
types that have less than 4 chi angles. If atom_positions_mask is set, the
chi mask will also mask out uncomputable chi angles.
"""
# Don't assert on the num_res and batch dimensions as they might be unknown.
assert positions.shape[-1] == rc.atom_type_num
assert mask.shape[-1] == rc.atom_type_num
no_batch_dims = len(aatype.shape) - 1
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices(aatype.device)
# DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why
# theirs works.
aatype_gapless = torch.clamp(aatype, max=20)
# Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4].
atom_indices = chi_atom_indices[aatype_gapless]
# Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
chi_angle_atoms = positions.map_tensor_fn(
partial(
tensor_utils.batched_gather,
inds=atom_indices,
dim=-1,
no_batch_dims=no_batch_dims + 1
)
)
a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
chi_angles = geometry.dihedral_angle(a, b, c, d)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device)
# Compute the chi angle mask. Shape [num_res, chis=4].
chi_mask = chi_angles_mask[aatype_gapless]
# The chi_mask is set to 1 only when all necessary chi angle atoms were set.
# Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
chi_angle_atoms_mask = tensor_utils.batched_gather(
mask,
atom_indices,
dim=-1,
no_batch_dims=no_batch_dims + 1
)
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
chi_mask = chi_mask * chi_angle_atoms_mask.to(chi_angles.dtype)
return chi_angles, chi_mask
def make_transform_from_reference(
a_xyz: geometry.Vec3Array,
b_xyz: geometry.Vec3Array,
c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
coordinates in the non-standard way, the A atom will end up in the negative
y-axis rather than in the positive y-axis. You need to take care of such
cases in your code.
Args:
a_xyz: A Vec3Array.
b_xyz: A Vec3Array.
c_xyz: A Vec3Array.
Returns:
A Rigid3Array which, when applied to coordinates in a canonicalized
reference frame, will give coordinates approximately equal
the original coordinates (in the global frame).
"""
rotation = geometry.Rot3Array.from_two_vectors(
c_xyz - b_xyz,
a_xyz - b_xyz
)
return geometry.Rigid3Array(rotation, b_xyz)
def make_backbone_affine(
positions: geometry.Vec3Array,
mask: torch.Tensor,
aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
a = rc.atom_order['N']
b = rc.atom_order['CA']
c = rc.atom_order['C']
rigid_mask = (mask[..., a] * mask[..., b] * mask[..., c])
rigid = make_transform_from_reference(
a_xyz=positions[..., a],
b_xyz=positions[..., b],
c_xyz=positions[..., c],
)
return rigid, rigid_mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment