OpenDAS / OpenFold · Commits

Commit f861ff39 (unverified), authored Dec 11, 2023 by Christina Floristean, committed by GitHub on Dec 11, 2023

Merge pull request #376 from dingquanyu/speedup-dataloader

Speed up data loading process

Parents: 1606ac08, 6f26b0ad
Showing 2 changed files with 78 additions and 20 deletions:

openfold/data/data_pipeline.py          +27 −20
openfold/data/tools/parse_msa_files.py  +51 −0
openfold/data/data_pipeline.py
@@ -21,16 +21,17 @@ import dataclasses
 from multiprocessing import cpu_count
 import tempfile
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
+import subprocess
 import numpy as np
 import torch
+import pickle
 from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
 from openfold.data.templates import get_custom_template_features, empty_template_feats
 from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
 from openfold.data.tools.utils import to_date
 from openfold.np import residue_constants, protein
+import concurrent
+from concurrent.futures import ThreadPoolExecutor
 
 FeatureDict = MutableMapping[str, np.ndarray]
 TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
@@ -735,22 +736,11 @@ class DataPipeline:
             fp.close()
         else:
-            for f in os.listdir(alignment_dir):
-                path = os.path.join(alignment_dir, f)
-                filename, ext = os.path.splitext(f)
-                if (ext == ".a3m"):
-                    with open(path, "r") as fp:
-                        msa = parsers.parse_a3m(fp.read())
-                elif (ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
-                    with open(path, "r") as fp:
-                        msa = parsers.parse_stockholm(fp.read())
-                else:
-                    continue
-                msa_data[f] = msa
+            # Now will split the following steps into multiple processes
+            current_directory = os.path.dirname(os.path.abspath(__file__))
+            cmd = f"{current_directory}/tools/parse_msa_files.py"
+            msa_data = subprocess.run(['python', cmd, f"--alignment_dir={alignment_dir}"], capture_output=True, text=True)
+            msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(), 'rb')))
 
         return msa_data
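The reworked else-branch hands the whole alignment directory to a child interpreter: the new tools/parse_msa_files.py script (added below) parses every MSA file in parallel, pickles the resulting dict to a temporary file, and prints that file's path as its only stdout line, which the parent strips and unpickles. Here is a minimal sketch of that handoff contract, separate from the committed code (worker.py and run_worker are placeholder names; the sketch also uses sys.executable, .strip(), and a context manager where the commit uses the literal 'python', .lstrip().rstrip(), and a bare open()):

    import pickle
    import subprocess
    import sys

    def run_worker(alignment_dir: str) -> dict:
        # Launch the worker in a child interpreter; by contract it prints the
        # path of a pickle file holding its results as its only stdout line.
        proc = subprocess.run(
            [sys.executable, "worker.py", f"--alignment_dir={alignment_dir}"],
            capture_output=True, text=True, check=True,
        )
        result_path = proc.stdout.strip()
        with open(result_path, "rb") as fh:  # close the handle after loading
            return pickle.load(fh)

Because the worker writes its pickle with delete=False, the temporary file outlives the call; a caller that cares can os.remove(result_path) after loading.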
@@ -826,6 +816,7 @@ class DataPipeline:
         input_sequence: Optional[str] = None,
         alignment_index: Optional[str] = None
     ) -> Mapping[str, Any]:
         msas = self._get_msas(
             alignment_dir, input_sequence, alignment_index
         )
@@ -1216,8 +1207,10 @@ class DataPipelineMultimer:
         with open(fasta_path) as f:
             input_fasta_str = f.read()
         input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
         all_chain_features = {}
         sequence_features = {}
         is_homomer_or_monomer = len(set(input_seqs)) == 1
@@ -1228,6 +1221,7 @@ class DataPipelineMultimer:
                 )
                 continue
             chain_features = self._process_single_chain(
                 chain_id=desc,
                 sequence=seq,
@@ -1236,6 +1230,7 @@ class DataPipelineMultimer:
                 is_homomer_or_monomer=is_homomer_or_monomer
             )
             chain_features = convert_monomer_features(
                 chain_features,
                 chain_id=desc
@@ -1243,17 +1238,20 @@ class DataPipelineMultimer:
             all_chain_features[desc] = chain_features
             sequence_features[seq] = chain_features
 
         all_chain_features = add_assembly_features(all_chain_features)
 
         np_example = feature_processing_multimer.pair_and_merge(
             all_chain_features=all_chain_features,
         )
 
         # Pad MSA to avoid zero-sized extra_msa.
         np_example = pad_msa(np_example, 512)
         return np_example
 
     def get_mmcif_features(
         self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
     ) -> FeatureDict:
@@ -1284,18 +1282,21 @@ class DataPipelineMultimer:
         alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         all_chain_features = {}
         sequence_features = {}
         is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
         for chain_id, seq in mmcif.chain_to_seqres.items():
             desc = "_".join([mmcif.file_id, chain_id])
             if seq in sequence_features:
                 all_chain_features[desc] = copy.deepcopy(
                     sequence_features[seq]
                 )
                 continue
             chain_features = self._process_single_chain(
                 chain_id=desc,
                 sequence=seq,
@@ -1304,23 +1305,29 @@ class DataPipelineMultimer:
                 is_homomer_or_monomer=is_homomer_or_monomer
             )
             chain_features = convert_monomer_features(
                 chain_features,
                 chain_id=desc
             )
             mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
             chain_features.update(mmcif_feats)
             all_chain_features[desc] = chain_features
             sequence_features[seq] = chain_features
 
         all_chain_features = add_assembly_features(all_chain_features)
 
         np_example = feature_processing_multimer.pair_and_merge(
             all_chain_features=all_chain_features,
         )
 
         # Pad MSA to avoid zero-sized extra_msa.
         np_example = pad_msa(np_example, 512)
         return np_example
\ No newline at end of file
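Both multimer featurization paths end by padding the merged example so extra_msa is never zero-sized. The pad_msa helper itself sits outside this diff; for context, the equivalent helper in AlphaFold-Multimer's pipeline_multimer.py, which OpenFold's multimer pipeline follows closely, looks roughly like the sketch below (background only, not the committed OpenFold source):

    import numpy as np

    def pad_msa(np_example, min_num_seq):
        # Zero-pad the MSA-shaped features along the sequence axis so the
        # example holds at least min_num_seq alignment rows.
        np_example = dict(np_example)
        num_seq = np_example['msa'].shape[0]
        if num_seq < min_num_seq:
            for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
                np_example[feat] = np.pad(
                    np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
            np_example['cluster_bias_mask'] = np.pad(
                np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
        return np_example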
openfold/data/tools/parse_msa_files.py (new file, mode 100644)
import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor


def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    infile.close()
    return {file_name: msa}


def parse_a3m_file(alignment_dir: str, a3m_file: str):
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    infile.close()
    return {file_name: msa}


def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str):
    # Number of workers based on the tasks
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    with ProcessPoolExecutor(max_workers=len(a3m_tasks) + len(sto_tasks)) as executor:
        a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
        sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}
        for future in concurrent.futures.as_completed(a3m_futures | sto_futures):
            try:
                result = future.result()
                msa_results.update(result)
            except Exception as exc:
                print(f'Task generated an exception: {exc}')
    return msa_results
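Two details of run_parse_all_msa_files_multiprocessing are worth noting. First, a3m_futures | sto_futures relies on dict union, so the script requires Python 3.9 or newer. Second, max_workers equals the total number of files, which can exceed the machine's core count when a directory holds many alignments; a bounded variant is sketched below as an assumption, not the committed code:

    import os
    from concurrent.futures import ProcessPoolExecutor

    def bounded_pool(n_tasks: int) -> ProcessPoolExecutor:
        # Hypothetical alternative: cap worker processes at the core count,
        # and never request fewer than one worker.
        return ProcessPoolExecutor(max_workers=max(1, min(n_tasks, os.cpu_count() or 1)))

The file's command-line entry point follows: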
def main():
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
    args = parser.parse_args()
    alignment_dir = args.alignment_dir
    stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))]
    a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
    msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
    print(outfile.name)


if __name__ == "__main__":
    main()
\ No newline at end of file
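A hypothetical round-trip through the new script (the alignment path is illustrative): run it, read the temporary-file path it prints, and unpickle a dict keyed by alignment basename. Note that the keys drop the file extension ("uniref90_hits", not "uniref90_hits.sto"), whereas the loop this commit replaces keyed msa_data by full filename.

    import pickle
    import subprocess
    import sys

    proc = subprocess.run(
        [sys.executable, "openfold/data/tools/parse_msa_files.py",
         "--alignment_dir=/path/to/alignments"],  # illustrative path
        capture_output=True, text=True, check=True,
    )
    with open(proc.stdout.strip(), "rb") as fh:
        msa_data = pickle.load(fh)  # e.g. {"uniref90_hits": Msa, "mgnify_hits": Msa, ...}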