"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "2fa93c6946e84a41156f37001df5dd2e56e4f7b5"
Commit 81ae777d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Make template parsing optional

parent ea484b71
...@@ -339,12 +339,16 @@ class DataPipeline: ...@@ -339,12 +339,16 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
templates_result = self.template_featurizer.get_templates( if(len(hits_cat) == 0):
query_sequence=input_sequence, template_features = {}
query_pdb_code=None, else:
query_release_date=None, templates_result = self.template_featurizer.get_templates(
hits=hits_cat, query_sequence=input_sequence,
) query_pdb_code=None,
query_release_date=None,
hits=hits_cat,
)
template_features = templates_result.features
sequence_features = make_sequence_features( sequence_features = make_sequence_features(
sequence=input_sequence, sequence=input_sequence,
...@@ -357,7 +361,7 @@ class DataPipeline: ...@@ -357,7 +361,7 @@ class DataPipeline:
return { return {
**sequence_features, **sequence_features,
**msa_features, **msa_features,
**templates_result.features **template_features
} }
def process_mmcif( def process_mmcif(
...@@ -384,17 +388,20 @@ class DataPipeline: ...@@ -384,17 +388,20 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
print(len(hits_cat)) if(len(hits_cat) == 0):
templates_result = self.template_featurizer.get_templates( template_features = {}
query_sequence=input_sequence, else:
query_pdb_code=None, templates_result = self.template_featurizer.get_templates(
query_release_date=to_date(mmcif.header["release_date"]), query_sequence=input_sequence,
hits=hits_cat, query_pdb_code=None,
) query_release_date=to_date(mmcif.header["release_date"]),
hits=hits_cat,
)
template_features = templates_result.features
msa_features = self._process_msa_feats(alignment_dir) msa_features = self._process_msa_feats(alignment_dir)
return {**mmcif_feats, **templates_result.features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
def process_pdb( def process_pdb(
self, self,
...@@ -413,13 +420,17 @@ class DataPipeline: ...@@ -413,13 +420,17 @@ class DataPipeline:
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
templates_result = self.template_featurizer.get_templates( if(len(hits_cat) == 0):
query_sequence=input_sequence, template_features = {}
query_pdb_code=None, else:
query_release_date=None, templates_result = self.template_featurizer.get_templates(
hits=hits_cat, query_sequence=input_sequence,
) query_pdb_code=None,
query_release_date=None,
hits=hits_cat,
)
template_features = templates_result.features
msa_features = self._process_msa_feats(alignment_dir) msa_features = self._process_msa_feats(alignment_dir)
return {**pdb_feats, **templates_result.features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment