"unicore/git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "689e0b248dd77a2ee32930d41f89e524c5d833f6"
Commit 6275091c authored by Christina Floristean
Browse files

Fixed the learning rate scheduler issue; reverted to the original MSA file parsing

parent bc075004
...@@ -21,14 +21,11 @@ import dataclasses ...@@ -21,14 +21,11 @@ import dataclasses
from multiprocessing import cpu_count from multiprocessing import cpu_count
import tempfile import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np import numpy as np
import torch import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
...@@ -704,10 +701,10 @@ class DataPipeline: ...@@ -704,10 +701,10 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = {} msa_data = {}
if(alignment_index is not None): if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
...@@ -718,14 +715,14 @@ class DataPipeline: ...@@ -718,14 +715,14 @@ class DataPipeline:
for (name, start, size) in alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name) filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if ext == ".a3m":
msa = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
# The "hmm_output" exception is a crude way to exclude # The "hmm_output" exception is a crude way to exclude
# multimer template hits. # multimer template hits.
# Multimer "uniprot_hits" processed separately. # Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]): elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size)) msa = parsers.parse_stockholm(read_msa(start, size))
else: else:
continue continue
...@@ -734,13 +731,22 @@ class DataPipeline: ...@@ -734,13 +731,22 @@ class DataPipeline:
fp.close() fp.close()
else: else:
# Now will split the following steps into multiple processes for f in os.listdir(alignment_dir):
current_directory = os.path.dirname(os.path.abspath(__file__)) path = os.path.join(alignment_dir, f)
cmd = f"{current_directory}/tools/parse_msa_files.py" filename, ext = os.path.splitext(f)
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip() if ext == ".a3m":
msa_data = pickle.load((open(msa_data_path,'rb'))) with open(path, "r") as fp:
os.remove(msa_data_path) msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data return msa_data
......
...@@ -63,7 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -63,7 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag), local_alignment_dir = os.path.join(alignment_dir, tag)
if args.use_precomputed_alignments is None: if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
......
...@@ -234,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -234,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler( lr_scheduler = AlphaFoldLRScheduler(
optimizer, optimizer,
last_epoch=self.last_lr_step
) )
return { return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment