OpenDAS / OpenFold
Commit 4e58a6a0 authored Dec 04, 2023 by Dingquan Yu

now use ThreadPoolExecutor

parent 2204bbb2
Showing 3 changed files with 31 additions and 17 deletions:

    openfold/data/data_modules.py     +10 -6
    openfold/data/data_pipeline.py     +7 -3
    openfold/data/parse_msa_files.py  +14 -8
openfold/data/data_modules.py  View file @ 4e58a6a0

...
@@ -22,9 +22,9 @@ from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )

-def calculate_elapse(start, end):
+def calculate_elapse(start, end, name):
     elapse = end - start
-    print(f"this function runs {round(elapse, 3)} seconds i.e. {round(elapse/60, 3)} minutes")
+    print(f"{name} runs {round(elapse, 3)} seconds i.e. {round(elapse/60, 3)} minutes")

 class OpenFoldSingleDataset(torch.utils.data.Dataset):
     def __init__(self,
...
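The renamed helper now tags each timing report with its call site, but every call site still does its own start/end bookkeeping, as the __getitem__ hunks below show. A minimal sketch of how the same helper could be wrapped in a context manager to cut that repetition; timed is hypothetical and not part of this commit:

import time
from contextlib import contextmanager

def calculate_elapse(start, end, name):
    # Same helper as in the diff: print labelled wall-clock time.
    elapse = end - start
    print(f"{name} runs {round(elapse, 3)} seconds i.e. {round(elapse/60, 3)} minutes")

@contextmanager
def timed(name):
    # Hypothetical wrapper: times the enclosed block and reports via the helper.
    start = time.time()
    try:
        yield
    finally:
        calculate_elapse(start, time.time(), name)

# Usage sketch:
with timed("dataset get_item in data_modules"):
    time.sleep(0.05)  # stand-in for the dataset's real work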
@@ -451,7 +451,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         mmcif_id = self.idx_to_mmcif_id(idx)
         alignment_index = None
+        import time
+        start = time.time()
         if self.mode == 'train' or self.mode == 'eval':
             path = os.path.join(self.data_dir, f"{mmcif_id}")
             ext = None
...
@@ -477,15 +478,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
             fasta_path=path,
             alignment_dir=self.alignment_dir
         )
+        end = time.time()
+        calculate_elapse(start, end, "process_fasta in data_modules")
         if self._output_raw:
             return data
+        process_start = time.time()
         # process all_chain_features
         data = self.feature_pipeline.process_features(data,
             mode=self.mode,
             is_multimer=True)
+        end = time.time()
+        calculate_elapse(process_start, end, "process_features in data_modules")
+        calculate_elapse(start, end, "dataset get_item in data_modules")
         # if it's inference mode, only need all_chain_features
         data["batch_idx"] = torch.tensor(
             [idx for _ in range(data["aatype"].shape[-1])],
...
openfold/data/data_pipeline.py  View file @ 4e58a6a0

...
@@ -738,10 +738,14 @@ class DataPipeline:
             fp.close()
         else:
             # Now will split the following steps into multiple processes
+            import time
             current_directory = os.path.dirname(os.path.abspath(__file__))
             cmd = f"{current_directory}/parse_msa_files.py"
+            start = time.time()
             msa_data = subprocess.run(['python', cmd, f"--alignment_dir={alignment_dir}"], capture_output=True, text=True)
-            msa_data = pickle.load((open(msa_data.stdout.rstrip(), 'rb')))
+            msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(), 'rb')))
+            end = time.time()
+            calculate_elapse(start, end, "parse_msa_files in data_pipeline")
             return msa_data

     def _parse_template_hit_files(
...
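The hunk above treats the child script's stdout as a path to a pickled result, so the added lstrip() guards against stray leading whitespace around that path (the trailing newline was already handled by rstrip(); plain .strip() does both). A self-contained sketch of that parent/child handoff, assuming parse_msa_files.py's main() pickles its result to a temp file and prints that file's path; the inline child source here is a hypothetical stand-in:

import pickle
import subprocess
import sys

# Hypothetical stand-in for parse_msa_files.py: dump a result to a temp file
# and print that file's path as the only stdout output.
child_src = """
import pickle, tempfile
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as fh:
    pickle.dump({'dummy': 'msa data'}, fh)
    print(fh.name)
"""

# Parent side, mirroring the diff: capture stdout as text, strip the
# surrounding whitespace/newline, then unpickle the file it points at.
proc = subprocess.run([sys.executable, '-c', child_src], capture_output=True, text=True)
result_path = proc.stdout.strip()  # equivalent to .lstrip().rstrip()
with open(result_path, 'rb') as fh:
    msa_data = pickle.load(fh)
print(msa_data)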
@@ -823,12 +827,12 @@ class DataPipeline:
             alignment_dir, input_sequence, alignment_index
         )
         end = time.time()
-        calculate_elapse(start, end, "get_msas")
+        calculate_elapse(start, end, "get_msas in data_pipeline")
         msa_features = make_msa_features(
             msas=msas
         )
         end_main = time.time()
-        calculate_elapse(start_main, end_main, "process_msa_feats")
+        calculate_elapse(start_main, end_main, "process_msa_feats in data_pipeline")
         return msa_features

     # Load and process sequence embedding features
...
openfold/data/parse_msa_files.py  View file @ 4e58a6a0

-import os, multiprocessing, argparse, pickle, tempfile
+import os, multiprocessing, argparse, pickle, tempfile, concurrent
 import multiprocessing.pool  # Need to import multiprocessing.pool first otherwise multiprocessing.pool.Pool cannot be called
 from openfold.data import parsers
+import contextlib
+from concurrent.futures import ThreadPoolExecutor

 def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
     path = os.path.join(alignment_dir, stockholm_file)
     file_name, _ = os.path.splitext(stockholm_file)
...
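One subtlety in the import hunk: a bare import concurrent does not load the concurrent.futures submodule (the package's __init__ is empty), so concurrent.futures.as_completed(...) in the hunk below would fail on its own. It works here because the from concurrent.futures import ThreadPoolExecutor line also runs, which imports the submodule and binds it as an attribute of the package. A quick sketch of the distinction, assuming nothing imported concurrent.futures earlier in the process:

import sys

import concurrent
# The bare package import leaves the submodule unloaded:
print(hasattr(concurrent, "futures"))        # False
print("concurrent.futures" in sys.modules)   # False

from concurrent.futures import ThreadPoolExecutor
# The from-import loads concurrent.futures and sets it on the parent package,
# so attribute access like concurrent.futures.as_completed now resolves:
print(hasattr(concurrent, "futures"))        # True
print("concurrent.futures" in sys.modules)   # True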
@@ -21,15 +22,20 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str):
     return {file_name: msa}

 def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str):
-    # Number of workers based on the tasks
     msa_results = {}
-    processes = []
     a3m_tasks = [(alignment_dir, f) for f in a3m_files]
     sto_tasks = [(alignment_dir, f) for f in stockholm_files]
-    with multiprocessing.pool.Pool(len(a3m_tasks) + len(sto_tasks)) as pool:
-        a3m_results = pool.starmap_async(parse_a3m_file, a3m_tasks).get()
-        sto_results = pool.starmap_async(parse_stockholm_file, sto_tasks).get()
-        for res in [*a3m_results, *sto_results]:
-            msa_results.update(res)
+    with ThreadPoolExecutor() as executor:
+        a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
+        sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}
+        for future in concurrent.futures.as_completed(a3m_futures | sto_futures):
+            try:
+                result = future.result()
+                msa_results.update(result)
+            except Exception as exc:
+                print(f'Task generated an exception: {exc}')
     return msa_results

 def main():
...
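This hunk is the core of the commit: the fixed-size multiprocessing.pool.Pool (one worker per task) and the blocking starmap_async(...).get() calls are replaced by a single ThreadPoolExecutor whose futures map back to their tasks, with results consumed in completion order and per-file exceptions caught instead of aborting the whole batch. Note that merging the two futures dicts with | requires Python 3.9+. A stand-alone sketch of the same pattern with a stub parser (all names here are illustrative, not from the commit):

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

def parse_stub(alignment_dir, file_name):
    # Stand-in for parse_a3m_file / parse_stockholm_file.
    if file_name.endswith('.bad'):
        raise ValueError(f'cannot parse {file_name}')
    return {file_name: f'msa parsed from {alignment_dir}/{file_name}'}

tasks = [('aln', 'uniref90.sto'), ('aln', 'bfd.a3m'), ('aln', 'broken.bad')]
msa_results = {}
with ThreadPoolExecutor() as executor:
    # Map each future back to its task so failures can be attributed.
    futures = {executor.submit(parse_stub, *task): task for task in tasks}
    for future in concurrent.futures.as_completed(futures):
        try:
            msa_results.update(future.result())
        except Exception as exc:
            print(f'Task {futures[future]} generated an exception: {exc}')
print(sorted(msa_results))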