Commit d71d37ff authored by Gustaf Ahdritz

Fix empty template feature bug

parent 81ae777d
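In short: when template search produced no hits, the pipeline returned an empty dict for template features, and downstream transforms that inspect template tensor shapes failed. The fix introduces empty_template_feats, which returns zero-template arrays with the correct trailing dimensions. A minimal sketch of why shaped-empty arrays are safer than an empty dict (illustration only, not part of the commit):

import numpy as np

n_res = 128
empty_aatype = np.zeros((0, n_res), dtype=np.int64)

# A (0, n_res) array still supports shape checks and concatenation:
print(empty_aatype.shape[0])  # 0 -> "no templates present"
stacked = np.concatenate([empty_aatype, np.zeros((2, n_res), dtype=np.int64)])
print(stacked.shape)  # (2, 128)

# An empty dict, by contrast, raises KeyError on the same lookups.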
@@ -189,7 +189,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
"crop_size": None,
"supervised": False,
@@ -202,7 +201,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
"crop_size": None,
"supervised": True,
@@ -215,7 +213,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128,
"max_template_hits": 20,
"max_templates": 4,
"num_ensemble": 1,
"crop": True,
"crop_size": 256,
"supervised": True,
......
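The num_ensemble entries deleted above are no longer consumed per data mode; as the input_pipeline hunk further down shows, the number of ensembled batches is now derived from the recycling count at runtime. A rough sketch of that relationship (function name hypothetical):

def num_ensembled_batches(no_recycling_iters: int) -> int:
    # One ensembled copy of the features per recycling pass, plus the
    # initial pass; mirrors torch.arange(num_recycling + 1) below.
    return no_recycling_iters + 1

assert num_ensembled_batches(3) == 4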
@@ -21,12 +21,21 @@ import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_sequence_features(
sequence: str, description: str, num_res: int
@@ -340,7 +349,7 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0):
template_features = {}
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
@@ -389,7 +398,7 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0):
template_features = {}
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
@@ -399,6 +408,12 @@
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
msa_features = self._process_msa_feats(alignment_dir)
return {**mmcif_feats, **template_features, **msa_features}
@@ -415,13 +430,14 @@ class DataPipeline:
pdb_str = pdb_path
protein_object = protein.from_pdb_string(pdb_str)
input_sequence = protein_object.aatype
pdb_feats = make_pdb_features(protein_object)
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0):
template_features = {}
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
......
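All three pipeline entry points (FASTA, mmCIF, PDB) now share the same guard: fall back to empty_template_feats when hit parsing finds nothing, and again when the featurizer itself returns zero-row features (the "quick fix" comment above). A condensed sketch of that control flow, reusing empty_template_feats from the hunk above; featurize stands in for self.template_featurizer.get_templates:

def template_feats_or_empty(hits_cat, input_sequence, featurize):
    if len(hits_cat) == 0:
        return empty_template_feats(len(input_sequence))
    features = featurize(input_sequence, hits_cat)
    # The featurizer may still emit zero templates in a malformed layout;
    # normalize that case to the canonical empty shapes as well.
    if features["template_aatype"].shape[0] == 0:
        features = empty_template_feats(len(input_sequence))
    return features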
@@ -85,17 +85,18 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
num_templates, -1
)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
if(num_templates > 0):
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
num_templates, -1
)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
@@ -169,10 +170,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0]
shuffled = torch.randperm(num_seq - 1) + 1
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(
@@ -1095,18 +1099,22 @@ def random_crop_to_size(
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
seq_length = protein["seq_length"]
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = protein["aatype"].new_zeros((1,))
num_templates = 0
num_res_crop_size = min(int(seq_length), crop_size)
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
num_res_crop_size = min(int(seq_length), crop_size)
def _randint(lower, upper):
return int(torch.randint(
......
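sample_msa and random_crop_to_size now thread an optional seed through a torch.Generator so that every ensemble iteration draws the same permutation and crop. A standalone sketch of the pattern (illustration only):

import torch

def seeded_permutation(n: int, seed=None) -> torch.Tensor:
    g = torch.Generator()
    if seed is not None:
        g.manual_seed(seed)
    return torch.randperm(n, generator=g)

# Same seed -> identical draws across calls (and ensemble iterations):
assert torch.equal(seeded_permutation(10, seed=42),
                   seeded_permutation(10, seed=42))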
@@ -86,8 +86,16 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
@@ -122,7 +130,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
seed=ensemble_seed,
seed=ensemble_seed + 1,
)
)
transforms.append(
@@ -159,21 +167,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i
return fn(d)
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
num_ensemble = mode_cfg.num_ensemble
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= num_recycling + 1
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
......
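Taken together: when resample_msa_in_recycling is off, MSA sampling is pinned to ensemble_seed; template subsampling gets ensemble_seed + 1 so the two draws stay decorrelated; and the ensemble map now always runs num_recycling + 1 times rather than multiplying a configured num_ensemble. A schematic of the loop (wrap_ensemble_fn stands in for the real closure):

def ensembled_batches(tensors, num_recycling, wrap_ensemble_fn):
    # Mirrors map_fn(lambda x: wrap_ensemble_fn(tensors, x),
    #                torch.arange(num_recycling + 1)).
    return [wrap_ensemble_fn(tensors, i) for i in range(num_recycling + 1)]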
@@ -241,31 +241,35 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
chunk_size,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]], dim=-3
template_mask = feats["template_mask"]
if(torch.any(template_mask)):
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
chunk_size,
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
......
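On the model side, the template branch is now gated on template_mask having any nonzero entry, which is false precisely when the pipeline fell back to zero-template features. A minimal illustration of the gate (illustration only):

import torch

template_mask = torch.zeros((0,))  # zero templates, as built upstream
if torch.any(template_mask):
    pass  # embed templates + merge with MSA/pair embeddings
else:
    pass  # skipped entirely for empty template features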