Commit d71d37ff authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix empty template feature bug

parent 81ae777d
...@@ -189,7 +189,6 @@ config = mlc.ConfigDict( ...@@ -189,7 +189,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"num_ensemble": 1,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"supervised": False, "supervised": False,
...@@ -202,7 +201,6 @@ config = mlc.ConfigDict( ...@@ -202,7 +201,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"num_ensemble": 1,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"supervised": True, "supervised": True,
...@@ -215,7 +213,6 @@ config = mlc.ConfigDict( ...@@ -215,7 +213,6 @@ config = mlc.ConfigDict(
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_template_hits": 20, "max_template_hits": 20,
"max_templates": 4, "max_templates": 4,
"num_ensemble": 1,
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"supervised": True, "supervised": True,
......
...@@ -21,12 +21,21 @@ import numpy as np ...@@ -21,12 +21,21 @@ import numpy as np
from openfold.data import templates, parsers, mmcif_parsing from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
def empty_template_feats(n_res) -> FeatureDict:
    """Build placeholder template features for a query with no template hits.

    Every array carries a leading "template" dimension of size 0 but is
    otherwise shaped for a query of ``n_res`` residues, so downstream code
    can concatenate or iterate without special-casing the empty case.

    Args:
        n_res: Number of residues in the query sequence.

    Returns:
        A FeatureDict of correctly shaped, empty template arrays.
    """
    # (shape, dtype) spec for each template feature.
    specs = {
        "template_aatype": ((0, n_res), np.int64),
        "template_all_atom_positions": ((0, n_res, 37, 3), np.float32),
        "template_sum_probs": ((0, 1), np.float32),
        "template_all_atom_mask": ((0, n_res, 37), np.float32),
    }
    return {
        key: np.zeros(shape, dtype=dtype)
        for key, (shape, dtype) in specs.items()
    }
def make_sequence_features( def make_sequence_features(
sequence: str, description: str, num_res: int sequence: str, description: str, num_res: int
...@@ -340,7 +349,7 @@ class DataPipeline: ...@@ -340,7 +349,7 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0): if(len(hits_cat) == 0):
template_features = {} template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = self.template_featurizer.get_templates( templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence, query_sequence=input_sequence,
...@@ -389,7 +398,7 @@ class DataPipeline: ...@@ -389,7 +398,7 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0): if(len(hits_cat) == 0):
template_features = {} template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = self.template_featurizer.get_templates( templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence, query_sequence=input_sequence,
...@@ -399,6 +408,12 @@ class DataPipeline: ...@@ -399,6 +408,12 @@ class DataPipeline:
) )
template_features = templates_result.features template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
msa_features = self._process_msa_feats(alignment_dir) msa_features = self._process_msa_feats(alignment_dir)
return {**mmcif_feats, **template_features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
...@@ -415,13 +430,14 @@ class DataPipeline: ...@@ -415,13 +430,14 @@ class DataPipeline:
pdb_str = pdb_path pdb_str = pdb_path
protein_object = protein.from_pdb_string(pdb_str) protein_object = protein.from_pdb_string(pdb_str)
input_sequence = protein_object.aatype
pdb_feats = make_pdb_features(protein_object) pdb_feats = make_pdb_features(protein_object)
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0): if(len(hits_cat) == 0):
template_features = {} template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = self.template_featurizer.get_templates( templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence, query_sequence=input_sequence,
......
...@@ -85,17 +85,18 @@ def make_all_atom_aatype(protein): ...@@ -85,17 +85,18 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein): def fix_templates_aatype(protein):
# Map one-hot to indices # Map one-hot to indices
num_templates = protein["template_aatype"].shape[0] num_templates = protein["template_aatype"].shape[0]
protein["template_aatype"] = torch.argmax( if(num_templates > 0):
protein["template_aatype"], dim=-1 protein["template_aatype"] = torch.argmax(
) protein["template_aatype"], dim=-1
# Map hhsearch-aatype to our aatype. )
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE # Map hhsearch-aatype to our aatype.
new_order = torch.tensor(new_order_list, dtype=torch.int64).expand( new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
num_templates, -1 new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
) num_templates, -1
protein["template_aatype"] = torch.gather( )
new_order, 1, index=protein["template_aatype"] protein["template_aatype"] = torch.gather(
) new_order, 1, index=protein["template_aatype"]
)
return protein return protein
...@@ -169,10 +170,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -169,10 +170,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
return protein return protein
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra): def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0] num_seq = protein["msa"].shape[0]
shuffled = torch.randperm(num_seq - 1) + 1 g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0) index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq) num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split( sel_seq, not_sel_seq = torch.split(
...@@ -1095,18 +1099,22 @@ def random_crop_to_size( ...@@ -1095,18 +1099,22 @@ def random_crop_to_size(
seed=None, seed=None,
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
seq_length = protein["seq_length"] seq_length = protein["seq_length"]
if "template_mask" in protein: if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1] num_templates = protein["template_mask"].shape[-1]
else: else:
num_templates = protein["aatype"].new_zeros((1,)) num_templates = 0
num_res_crop_size = min(int(seq_length), crop_size) # No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
# We want each ensemble to be cropped the same way num_res_crop_size = min(int(seq_length), crop_size)
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
def _randint(lower, upper): def _randint(lower, upper):
return int(torch.randint( return int(torch.randint(
......
...@@ -86,8 +86,16 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -86,8 +86,16 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
max_msa_clusters = pad_msa_clusters max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa max_extra_msa = common_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append( transforms.append(
data_transforms.sample_msa(max_msa_clusters, keep_extra=True) data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True,
seed=msa_seed,
)
) )
if "masked_msa" in common_cfg: if "masked_msa" in common_cfg:
...@@ -122,7 +130,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -122,7 +130,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
mode_cfg.max_templates, mode_cfg.max_templates,
crop_feats, crop_feats,
mode_cfg.subsample_templates, mode_cfg.subsample_templates,
seed=ensemble_seed, seed=ensemble_seed + 1,
) )
) )
transforms.append( transforms.append(
...@@ -159,21 +167,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -159,21 +167,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i d["ensemble_index"] = i
return fn(d) return fn(d)
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors) no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
num_ensemble = mode_cfg.num_ensemble nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors): if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"]) num_recycling = int(tensors["no_recycling_iters"])
else: else:
num_recycling = common_cfg.max_recycling_iters num_recycling = common_cfg.max_recycling_iters
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= num_recycling + 1
tensors = map_fn( tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble) lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
) )
return tensors return tensors
......
...@@ -241,31 +241,35 @@ class AlphaFold(nn.Module): ...@@ -241,31 +241,35 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled: if self.config.template.enabled:
template_feats = { template_mask = feats["template_mask"]
k: v for k, v in feats.items() if k.startswith("template_") if(torch.any(template_mask)):
} template_feats = {
template_embeds = self.embed_templates( k: v for k, v in feats.items() if k.startswith("template_")
template_feats, }
z, template_embeds = self.embed_templates(
pair_mask, template_feats,
no_batch_dims, z,
chunk_size, pair_mask,
) no_batch_dims,
chunk_size,
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]], dim=-3
) )
# [*, S, N] # [*, N, N, C_z]
torsion_angles_mask = feats["template_torsion_angles_mask"] z = z + template_embeds["template_pair_embedding"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2 if self.config.template.embed_angles:
) # [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# Embed extra MSA features + merge with pairwise embeddings # Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled: if self.config.extra_msa.enabled:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment