Commit c1129bef authored by Christina Floristean

Fixed bug in triangle multiplicative update and added early stop recycling.

parent 425bdb5e
@@ -155,21 +155,38 @@ def model_config(
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.globals.is_multimer = True
c.globals.bfloat16 = True
c.globals.bfloat16_output = False
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20
for k,v in multimer_model_config_update.items():
c.model[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
c.model.evoformer.num_msa = 252
c.model.evoformer.num_extra_msa= 1152
c.model.evoformer.fuse_projection_weights = False
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
for k,v in multimer_model_config_update.items():
c.model[k] = v
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
c.data.train.max_msa_clusters = 508
c.data.predict.max_msa_clusters = 508
c.data.train.max_extra_msa = 2048
c.data.predict.max_extra_msa = 2048
c.data.common.unsupervised_features.extend([
"msa_mask",
@@ -646,6 +663,12 @@ config = mlc.ConfigDict(
"eps": eps,
},
"ema": {"decay": 0.999},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1
}
)
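As a quick illustration of the new flag documented above, a hypothetical usage sketch (the preset name and attribute paths are inferred from this diff, not from documented API):

from openfold.config import model_config

# Load a multimer preset; per the update above these ship with a 0.5 Å tolerance.
c = model_config("model_1_multimer_v3")
print(c.model.recycle_early_stop_tolerance)   # 0.5 for multimer presets
c.model.recycle_early_stop_tolerance = -1     # a negative value disables early stopping
c.data.common.max_recycling_iters = 20        # recycling cap that applies either way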
@@ -653,6 +676,7 @@ multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
@@ -703,6 +727,7 @@ multimer_model_config_update = {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
#"num_extra_msa": 2048
},
"extra_msa_stack": {
"c_m": c_e,
@@ -788,5 +813,5 @@ multimer_model_config_update = {
"c_out": 37,
},
},
"recycle_early_stop_tolerance": 0.5
}
@@ -280,6 +280,7 @@ def run_msa_tool(
else:
result = msa_runner.query(fasta_path)[0]
assert msa_out_path.split('.')[-1] == msa_format
with open(msa_out_path, "w") as f:
f.write(result[msa_format])
@@ -321,6 +322,7 @@ def make_sequence_features_with_custom_template(
**template_features.features
}
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
@@ -372,6 +374,8 @@ class AlignmentRunner:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
uniprot_max_hits:
Max number of uniprot hits
"""
db_map = {
"jackhmmer": {
@@ -468,7 +472,7 @@ class AlignmentRunner:
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
@@ -505,7 +509,7 @@ class AlignmentRunner:
)
if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
fasta_path=fasta_path,
@@ -719,16 +723,14 @@ class DataPipeline:
msa = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
msa = parsers.parse_stockholm(read_msa(start, size))
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue
msa_data[name] = data
msa_data[name] = msa
fp.close()
else:
@@ -739,17 +741,15 @@ class DataPipeline:
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue
msa_data[f] = data
msa_data[f] = msa
return msa_data
@@ -831,8 +831,6 @@ class DataPipeline:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
@@ -849,16 +847,11 @@ class DataPipeline:
)
deletion_matrix = [[0 for _ in input_sequence]]
msa_data["dummy"] = {
"msa": parsers.Msa(sequences=input_sequence, deletion_matrix=deletion_matrix, descriptions=None),
"deletion_matrix": deletion_matrix,
}
msa_data["dummy"] = parsers.Msa(sequences=input_sequence,
deletion_matrix=deletion_matrix,
descriptions=None)
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
return list(msa_data.values())
def _process_msa_feats(
self,
@@ -866,7 +859,7 @@ class DataPipeline:
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
msas = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
@@ -944,7 +937,6 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
alignment_index)
template_features = make_template_features(
@@ -994,7 +986,6 @@ class DataPipeline:
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
alignment_index
)
@@ -1080,11 +1071,11 @@ class DataPipeline:
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
msas = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
msa_list.append([m.sequences for m in msas])
deletion_mat_list.append([m.deletion_matrix for m in msas])
final_msa = []
final_deletion_mat = []
@@ -1181,12 +1172,10 @@ class DataPipelineMultimer:
def _all_seq_msa_features(self, fasta_path, alignment_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
#TODO: Quick fix, change back to .sto after parsing fixed
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.a3m")
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_a3m(uniprot_msa_string)
#msa = parsers.parse_stockholm(uniprot_msa_string)
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
@@ -902,7 +902,7 @@ def _process_single_hit(
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.sum_probs if hit.sum_probs else 0.,
hit.index,
str(e),
parsing_result.errors,
@@ -919,7 +919,7 @@ def _process_single_hit(
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.sum_probs if hit.sum_probs else 0.,
hit.index,
str(e),
parsing_result.errors,
@@ -525,16 +525,14 @@ class EvoformerBlock(MSABlock):
_attn_chunk_size=_attn_chunk_size
)
m = input_tensors[0]
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
m, _ = input_tensors
else:
m = input_tensors[0]
return m, z
@@ -713,12 +711,10 @@ class ExtraMSABlock(MSABlock):
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
del m
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
m, _ = input_tensors
return m, z
@@ -25,6 +25,7 @@ from openfold.utils.feats import (
dgram_from_positions,
atom14_to_atom37,
)
from openfold.utils.tensor_utils import masked_mean
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
@@ -165,6 +166,38 @@ class AlphaFold(nn.Module):
return template_embeds
def tolerance_reached(self, prev_pos, next_pos, mask, no_batch_dims, eps=1e-8) -> bool:
"""
Early stopping criterion based on the approach used in
AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
Args:
prev_pos: Previous atom positions in atom37/14 representation
next_pos: Current atom positions in atom37/14 representation
mask: 1-D sequence mask
no_batch_dims: Number of batch dimensions in the position tensors
eps: Epsilon used in the square-root calculation
Returns:
Whether to stop recycling early based on the desired tolerance.
"""
def distances(points):
"""Compute all pairwise distances for a set of points."""
d = points[..., None, :] - points[..., None, :, :]
return torch.sqrt(torch.sum(d ** 2, dim=-1))
if self.config.recycle_early_stop_tolerance < 0:
return False
if no_batch_dims == 0:
prev_pos = prev_pos.unsqueeze(dim=0)
next_pos = next_pos.unsqueeze(dim=0)
mask = mask.unsqueeze(dim=0)
ca_idx = residue_constants.atom_order['CA']
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
mask = mask[..., None] * mask[..., None, :]
sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
diff = torch.sqrt(sq_diff + eps)
return diff <= self.config.recycle_early_stop_tolerance
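For reference, a self-contained toy version of the tolerance check above (assumed shapes and a hard-coded CA index; in OpenFold the index comes from residue_constants.atom_order['CA']):

import torch

def pairwise_distances(points):  # points: [N, 3]
    d = points[:, None, :] - points[None, :, :]
    return torch.sqrt((d ** 2).sum(dim=-1))

N, ca_idx = 8, 1                                    # toy sequence length and CA atom index
prev_pos = torch.randn(N, 37, 3)                    # previous recycling iteration, atom37
next_pos = prev_pos + 0.01 * torch.randn(N, 37, 3)  # current iteration, slightly perturbed
mask = torch.ones(N)                                # all residues valid

sq_diff = (pairwise_distances(prev_pos[:, ca_idx]) - pairwise_distances(next_pos[:, ca_idx])) ** 2
pair_mask = mask[:, None] * mask[None, :]
diff = torch.sqrt((sq_diff * pair_mask).sum() / pair_mask.sum() + 1e-8)
stop = bool(diff <= 0.5)                            # 0.5 = recycle_early_stop_tolerance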
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
outputs = {}
@@ -263,7 +296,7 @@ class AlphaFold(nn.Module):
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
@@ -406,10 +439,16 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z_prev = outputs["pair"]
early_stop = False
if self.globals.is_multimer:
early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask, no_batch_dims)
del x_prev
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
return outputs, m_1_prev, z_prev, x_prev, early_stop
def _disable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = None
@@ -488,13 +527,14 @@ class AlphaFold(nn.Module):
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
early_stop = False
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
is_final_iter = cycle_no == (num_iters - 1) or early_stop
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
# Sidestep AMP bug (PyTorch issue #65766)
@@ -502,16 +542,18 @@ class AlphaFold(nn.Module):
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
feats,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
if not is_final_iter:
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
else:
break
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
@@ -509,7 +509,6 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask):
pair = self.layer_norm_in(pair)
p = self.linear_ab_g(pair)
p.sigmoid_()
p *= self.linear_ab_p(pair)
@@ -519,16 +518,21 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask)
a = p[..., :self.c_hidden]
b = p[..., self.c_hidden:]
if self._outgoing:
left = p[..., :self.c_hidden]
right = p[..., self.c_hidden:]
else:
left = p[..., self.c_hidden:]
right = p[..., :self.c_hidden]
return a, b
return left, right
a, b = compute_projection(z, mask)
z_norm_in = self.layer_norm_in(z)
a, b = compute_projection(z_norm_in, mask)
x = self._combine_projections(a, b, _inplace_chunk_size=_inplace_chunk_size)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g = self.linear_g(z_norm_in)
g.sigmoid_()
x *= g
if (with_add):
@@ -573,8 +577,12 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z)
if self._outgoing:
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
else:
b = ab[..., :self.c_hidden]
a = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
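The direction-dependent split above decides which half of the fused ab projection acts as the left and right operand. As a minimal reference (assumed shapes, not the OpenFold implementation), the two triangle-update directions combine those operands like this:

import torch

N, C = 4, 8
left = torch.randn(N, N, C)    # "a" projection
right = torch.randn(N, N, C)   # "b" projection

# "Outgoing" edges: update z_ij from edges leaving i and j, summing over k of a_ik * b_jk.
outgoing = torch.einsum("ikc,jkc->ijc", left, right)
# "Incoming" edges: update z_ij from edges entering i and j, summing over k of a_ki * b_kj.
incoming = torch.einsum("kic,kjc->ijc", left, right)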
import argparse
import logging
import os
import string
from collections import defaultdict
from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants
@@ -22,7 +23,7 @@ def main(args):
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {fname}...')
if(args.raise_errors):
raise list(mmcif.errors.values())[0]
raise Exception(list(mmcif.errors.values())[0])
else:
continue
@@ -31,6 +32,25 @@ def main(args):
chain_id = '_'.join([basename, chain])
fasta.append(f">{chain_id}")
fasta.append(seq)
elif(ext == ".pdb"):
with open(fpath, 'r') as fp:
pdb_str = fp.read()
protein_object = protein.from_pdb_string(pdb_str)
aatype = protein_object.aatype
chain_index = protein_object.chain_index
last_chain_index = chain_index[0]
chain_dict = defaultdict(list)
for i in range(aatype.shape[0]):
chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
chain_tags = string.ascii_uppercase
for chain, seq in chain_dict.items():
chain_id = '_'.join([basename, chain_tags[chain]])
fasta.append(f">{chain_id}")
fasta.append(''.join(seq))
elif(ext == ".core"):
with open(fpath, 'r') as fp:
core_str = fp.read()
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
for chain_dir in $(ls "${RODA_DIR}"); do
CHAIN_DIR_PATH="${RODA_DIR}/${chain_dir}"
for subdir in $(ls "${CHAIN_DIR_PATH}"); do
if [[ $subdir = "pdb" ]] || [[ $subdir = "cif" ]]; then
if [[ ! -d "${CHAIN_DIR_PATH}/${subdir}" ]]; then
echo "${subdir} is not a directory"
continue
elif [[ -z $(ls "${CHAIN_DIR_PATH}/${subdir}") ]]; then
continue
elif [[ $subdir = "pdb" ]] || [[ $subdir = "cif" ]]; then
mv "${CHAIN_DIR_PATH}/${subdir}"/* "${DATA_DIR}"
else
CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_dir}"
@@ -4,10 +4,11 @@ import json
import logging
from multiprocessing import Pool
import os
import string
import sys
sys.path.append(".") # an innocent hack to get this to run from the top level
from collections import defaultdict
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
@@ -49,20 +50,27 @@ def parse_file(
pdb_string = fp.read()
protein_object = protein.from_pdb_string(pdb_string, None)
aatype = protein_object.aatype
chain_index = protein_object.chain_index
chain_dict = {}
chain_dict["seq"] = residue_constants.aatype_to_str_sequence(
protein_object.aatype,
)
chain_dict["resolution"] = 0.
chain_dict = defaultdict(list)
for i in range(aatype.shape[0]):
chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
out = {}
chain_tags = string.ascii_uppercase
for chain, seq in chain_dict.items():
full_name = "_".join([file_id, chain_tags[chain]])
out[full_name] = {}
local_data = out[full_name]
local_data["resolution"] = 0.
local_data["seq"] = ''.join(seq)
if(chain_cluster_size_dict is not None):
cluster_size = chain_cluster_size_dict.get(
full_name.upper(), -1
)
chain_dict["cluster_size"] = cluster_size
out = {file_id: chain_dict}
local_data["cluster_size"] = cluster_size
return out
@@ -40,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
alignment_runner.run(
fasta_path, alignment_dir
)
except:
except Exception as e:
logging.warning(e)
logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
os.remove(fasta_path)
os.rmdir(alignment_dir)