Commit 6133ea95 authored by Gustaf Ahdritz
Browse files

Fix .core parsing bug

parent 79f9f03d
...@@ -45,7 +45,7 @@ def make_template_features( ...@@ -45,7 +45,7 @@ def make_template_features(
query_release_date: Optional[str] = None, query_release_date: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0): if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence)) template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = template_featurizer.get_templates( templates_result = template_featurizer.get_templates(
...@@ -130,7 +130,8 @@ def _aatype_to_str_sequence(aatype): ...@@ -130,7 +130,8 @@ def _aatype_to_str_sequence(aatype):
def make_protein_features( def make_protein_features(
protein_object: protein.Protein, protein_object: protein.Protein,
description: str, description: str,
_is_distillation: bool = False,
) -> FeatureDict: ) -> FeatureDict:
pdb_feats = {} pdb_feats = {}
aatype = protein_object.aatype aatype = protein_object.aatype
...@@ -150,7 +151,9 @@ def make_protein_features( ...@@ -150,7 +151,9 @@ def make_protein_features(
pdb_feats["all_atom_mask"] = all_atom_mask pdb_feats["all_atom_mask"] = all_atom_mask
pdb_feats["resolution"] = np.array([0.]).astype(np.float32) pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(1.).astype(np.float32) pdb_feats["is_distillation"] = np.array(
1. if _is_distillation else 0.
).astype(np.float32)
return pdb_feats return pdb_feats
...@@ -160,7 +163,10 @@ def make_pdb_features( ...@@ -160,7 +163,10 @@ def make_pdb_features(
description: str, description: str,
confidence_threshold: float = 0.5, confidence_threshold: float = 0.5,
) -> FeatureDict: ) -> FeatureDict:
pdb_feats = make_protein_features(protein_object, description) """ Use only for distillation set PDBs """
pdb_feats = make_protein_features(
protein_object, description, _is_distillation=True
)
high_confidence = protein_object.b_factors > confidence_threshold high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1) high_confidence = np.any(high_confidence, axis=-1)
...@@ -312,7 +318,7 @@ class DataPipeline: ...@@ -312,7 +318,7 @@ class DataPipeline:
"""Assembles input features.""" """Assembles input features."""
def __init__( def __init__(
self, self,
template_featurizer: templates.TemplateHitFeaturizer, template_featurizer: Optional[templates.TemplateHitFeaturizer],
): ):
self.template_featurizer = template_featurizer self.template_featurizer = template_featurizer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment