Commit bcabb8e3 authored by Christina Floristean
Browse files

Fix bug from deepspeed upgrade, do not use uniprot_hits in msa feats

parent f861ff39
......@@ -2,6 +2,7 @@ import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
path = os.path.join(alignment_dir, stockholm_file)
file_name,_ = os.path.splitext(stockholm_file)
......@@ -10,6 +11,7 @@ def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
infile.close()
return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str):
path = os.path.join(alignment_dir, a3m_file)
file_name,_ = os.path.splitext(a3m_file)
......@@ -18,6 +20,7 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str):
infile.close()
return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str):
# Number of workers based on the tasks
msa_results={}
......@@ -35,17 +38,20 @@ def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: li
print(f'Task generated an exception: {exc}')
return msa_results
def main():
    """Parse the MSA files found in ``--alignment_dir`` and pickle the result.

    Command-line arguments:
        --alignment_dir: directory that is scanned for ``.sto`` and ``.a3m``
            files.

    Stockholm (``.sto``) files whose name contains ``hmm_output`` or
    ``uniprot_hits`` are left out; every ``.a3m`` file is kept. The selected
    files are handed to ``run_parse_all_msa_files_multiprocessing`` and the
    value it returns is pickled into a temporary ``.pkl`` file that is NOT
    removed on close. The path of that file is printed on stdout
    (presumably so the calling process can load the pickle -- confirm
    against the caller).
    """
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
    args = parser.parse_args()
    alignment_dir = args.alignment_dir

    # Scan the directory once so both file lists come from the same snapshot.
    dir_entries = os.listdir(alignment_dir)

    # Single filter for Stockholm files; the earlier variant that only
    # excluded "hmm_output" was a dead assignment overwritten by this one.
    excluded_markers = ("hmm_output", "uniprot_hits")
    stockholm_files = [
        name for name in dir_entries
        if name.endswith('.sto')
        and not any(marker in name for marker in excluded_markers)
    ]
    a3m_files = [name for name in dir_entries if name.endswith('.a3m')]

    msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)

    # delete=False keeps the pickle on disk after the handle is closed.
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
    # Announce the path only after the file is flushed and closed.
    print(outfile.name)


if __name__ == "__main__":
    main()
\ No newline at end of file
......@@ -189,7 +189,7 @@ class Linear(nn.Linear):
d = input.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if self.precision is not None:
with torch.cuda.amp.autocast(enabled=False):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment