Commit e8e0b66f authored by Gustaf Ahdritz

Make minor formatting edits

parent f8c81df4
@@ -62,7 +62,7 @@ class DataPipeline:
     """Runs the alignment tools and assembles the input features."""
     def __init__(self,
-        jackhammer_binary_path: str,
+        jackhmmer_binary_path: str,
         hhblits_binary_path: str,
         hhsearch_binary_path: str,
         uniref90_database_path: str,
@@ -79,12 +79,12 @@ class DataPipeline:
         """Constructs a feature dict for a given FASTA file."""
         self._use_small_bfd = use_small_bfd
         self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
-            binary_path=jackhammer_binary_path,
+            binary_path=jackhmmer_binary_path,
             database_path=uniref90_database_path
         )
         if use_small_bfd:
             self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
-                binary_path=jackhammer_binary_path,
+                binary_path=jackhmmer_binary_path,
                 database_path=small_bfd_database_path
             )
         else:
@@ -93,7 +93,7 @@ class DataPipeline:
             databases=[bfd_database_path, uniclust30_database_path]
         )
         self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
-            binary_path=jackhammer_binary_path,
+            binary_path=jackhmmer_binary_path,
             database_path=mgnify_database_path
         )
         self.hhsearch_pdb70_runner = hhsearch.HHSearch(
...
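The substance of this file's change is the keyword rename from jackhammer_binary_path to jackhmmer_binary_path, which every caller has to mirror (the run script further down is updated accordingly). A minimal, self-contained sketch of why such a rename is a breaking change, using a hypothetical stand-in function rather than the real DataPipeline class, with placeholder paths:

def build_pipeline(jackhmmer_binary_path: str, uniref90_database_path: str) -> dict:
    # Hypothetical stand-in for DataPipeline.__init__ with the corrected keyword.
    return {
        "jackhmmer": jackhmmer_binary_path,
        "uniref90": uniref90_database_path,
    }

# Callers must use the new spelling; both paths are placeholders.
build_pipeline(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
)

# Keeping the old keyword now fails immediately:
#   build_pipeline(jackhammer_binary_path=..., ...)
#   TypeError: build_pipeline() got an unexpected keyword argument 'jackhammer_binary_path'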
@@ -198,7 +198,7 @@ class T:
         denom = torch.sqrt(sum((c * c for c in e0)) + eps)
         e0 = [c / denom for c in e0]
         dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
-        e1 = [c1 - c2 * dot for c1, c2 in zip(e1, e0)]
+        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
         denom = torch.sqrt(sum((c * c for c in e1)) + eps)
         e1 = [c / denom for c in e1]
         e2 = [
...
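For context, the rewritten comprehension is the projection step of a Gram-Schmidt orthonormalization over per-coordinate component lists; the new form computes the same values as the old one (only the zip order and loop-variable names change), consistent with the commit message. A runnable sketch of the same steps with illustrative tensors, not code taken from this repository:

import torch

# Illustrative component lists (three scalar tensors each); in the real code
# these come from frame construction, not from constants.
eps = 1e-8
e0 = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(2.0)]
e1 = [torch.tensor(0.0), torch.tensor(1.0), torch.tensor(0.0)]

# Normalize e0.
denom = torch.sqrt(sum(c * c for c in e0) + eps)
e0 = [c / denom for c in e0]

# Project out the e0 component of e1 (the line touched by this commit),
# then normalize e1 as well.
dot = sum(c1 * c2 for c1, c2 in zip(e0, e1))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum(c * c for c in e1) + eps)
e1 = [c / denom for c in e1]

print(sum(c1 * c2 for c1, c2 in zip(e0, e1)))  # ~0: e0 and e1 are now orthonormal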
@@ -63,10 +63,11 @@ def main(args):
         max_hits=MAX_TEMPLATE_HITS,
         kalign_binary_path=args.kalign_binary_path,
         release_dates_path=None,
-        obsolete_pdbs_path=args.obsolete_pdbs_path)
+        obsolete_pdbs_path=args.obsolete_pdbs_path
+    )
     data_processor = data_pipeline.DataPipeline(
-        jackhammer_binary_path=args.jackhmmer_binary_path,
+        jackhmmer_binary_path=args.jackhmmer_binary_path,
         hhblits_binary_path=args.hhblits_binary_path,
         hhsearch_binary_path=args.hhsearch_binary_path,
         uniref90_database_path=args.uniref90_database_path,
@@ -94,16 +95,11 @@ def main(args):
     print("Collecting data...")
     feature_dict = data_processor.process(
         input_fasta_path=args.fasta_path, msa_output_dir=msa_output_dir)
 
-    # Output the features
-    features_output_path = os.path.join(output_dir_base, 'features.pkl')
-    with open(features_output_path, 'wb') as f:
-        pickle.dump(feature_dict, f, protocol=4)
-
     print("Generating features...")
-    processed_feature_dict = feature_processor.process_features(feature_dict, random_seed)
-    with open(os.path.join(output_dir_base, 'processed_feats.pkl'), 'wb') as f:
-        pickle.dump(processed_feature_dict, f, protocol=4)
+    processed_feature_dict = feature_processor.process_features(
+        feature_dict, random_seed
+    )
 
     print("Executing model...")
     batch = processed_feature_dict
@@ -209,7 +205,8 @@ if __name__ == "__main__":
     parser.add_argument(
         '--kalign_binary_path', type=str, default='/usr/bin/kalign'
     )
-    parser.add_argument('--uniref90_database_path', type=str, default=None, required=True
+    parser.add_argument(
+        '--uniref90_database_path', type=str, default=None, required=True
     )
     parser.add_argument(
         '--mgnify_database_path', type=str, default=None, required=True
...