"vscode:/vscode.git/clone" did not exist on "e4fb2aa4ea2a6347df67473bb2f02169510a1cab"
Commit 407d9924 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add empty template handling to inference pipeline

parent 8f7e90d6
...@@ -359,6 +359,11 @@ class DataPipeline: ...@@ -359,6 +359,11 @@ class DataPipeline:
) )
template_features = templates_result.features template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
sequence_features = make_sequence_features( sequence_features = make_sequence_features(
sequence=input_sequence, sequence=input_sequence,
description=input_description, description=input_description,
...@@ -397,6 +402,7 @@ class DataPipeline: ...@@ -397,6 +402,7 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
print(len(hits_cat))
if(len(hits_cat) == 0): if(len(hits_cat) == 0):
template_features = empty_template_feats(len(input_sequence)) template_features = empty_template_feats(len(input_sequence))
else: else:
...@@ -447,6 +453,11 @@ class DataPipeline: ...@@ -447,6 +453,11 @@ class DataPipeline:
) )
template_features = templates_result.features template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
msa_features = self._process_msa_feats(alignment_dir) msa_features = self._process_msa_feats(alignment_dir)
return {**pdb_feats, **template_features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
...@@ -79,7 +79,6 @@ def main(args): ...@@ -79,7 +79,6 @@ def main(args):
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd
) )
output_dir_base = args.output_dir output_dir_base = args.output_dir
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment