"examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f7a60cba6a078f283f575784e0110b25dc397b7f"
Commit bcabb8e3 authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix bug from deepspeed upgrade, do not use uniprot_hits in msa feats

parent f861ff39
...@@ -2,6 +2,7 @@ import os, argparse, pickle, tempfile, concurrent ...@@ -2,6 +2,7 @@ import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str): def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
path = os.path.join(alignment_dir, stockholm_file) path = os.path.join(alignment_dir, stockholm_file)
file_name,_ = os.path.splitext(stockholm_file) file_name,_ = os.path.splitext(stockholm_file)
...@@ -10,6 +11,7 @@ def parse_stockholm_file(alignment_dir: str, stockholm_file: str): ...@@ -10,6 +11,7 @@ def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
infile.close() infile.close()
return {file_name: msa} return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str): def parse_a3m_file(alignment_dir: str, a3m_file: str):
path = os.path.join(alignment_dir, a3m_file) path = os.path.join(alignment_dir, a3m_file)
file_name,_ = os.path.splitext(a3m_file) file_name,_ = os.path.splitext(a3m_file)
...@@ -18,6 +20,7 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str): ...@@ -18,6 +20,7 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str):
infile.close() infile.close()
return {file_name: msa} return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str): def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str):
# Number of workers based on the tasks # Number of workers based on the tasks
msa_results={} msa_results={}
...@@ -35,17 +38,20 @@ def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: li ...@@ -35,17 +38,20 @@ def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: li
print(f'Task generated an exception: {exc}') print(f'Task generated an exception: {exc}')
return msa_results return msa_results
def main(): def main():
parser = argparse.ArgumentParser(description='Process msa files in parallel') parser = argparse.ArgumentParser(description='Process msa files in parallel')
parser.add_argument('--alignment_dir', type=str, help='path to alignment dir') parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
args = parser.parse_args() args = parser.parse_args()
alignment_dir = args.alignment_dir alignment_dir = args.alignment_dir
stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))] stockholm_files = [i for i in os.listdir(alignment_dir)
if all([i.endswith('.sto'), "hmm_output" not in i, "uniprot_hits" not in i])]
a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')] a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir) msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile: with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
pickle.dump(msa_data, outfile) pickle.dump(msa_data, outfile)
print(outfile.name) print(outfile.name)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -189,7 +189,7 @@ class Linear(nn.Linear): ...@@ -189,7 +189,7 @@ class Linear(nn.Linear):
d = input.dtype d = input.dtype
deepspeed_is_initialized = ( deepspeed_is_initialized = (
deepspeed_is_installed and deepspeed_is_installed and
deepspeed.utils.is_initialized() deepspeed.comm.comm.is_initialized()
) )
if self.precision is not None: if self.precision is not None:
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment