Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e72e4e62
Commit
e72e4e62
authored
Dec 05, 2023
by
Dingquan Yu
Browse files
remove debugging statement
parent
53c03a6a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
29 deletions
+29
-29
openfold/data/data_modules.py
openfold/data/data_modules.py
+5
-11
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+24
-18
No files found.
openfold/data/data_modules.py
View file @
e72e4e62
...
@@ -22,9 +22,6 @@ from openfold.utils.tensor_utils import (
...
@@ -22,9 +22,6 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
tensor_tree_map
,
)
)
def calculate_elapse(start, end, name):
    """Print how long a named step took and return the elapsed seconds.

    Debug/profiling helper: reports the wall-clock duration between two
    ``time.time()`` timestamps, rounded for readability.

    Args:
        start: Start timestamp in seconds (e.g. from ``time.time()``).
        end: End timestamp in seconds.
        name: Human-readable label for the step being timed.

    Returns:
        The elapsed time in seconds (``end - start``).
    """
    elapse = end - start
    # Round to millisecond precision; minutes are derived for long-running steps.
    print(f"{name} runs {round(elapse, 3)} seconds i.e. {round(elapse / 60, 3)} minutes")
    # Return the computed value so callers can use it (previously discarded).
    return elapse
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -451,9 +448,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -451,9 +448,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
alignment_index
=
None
alignment_index
=
None
if
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
:
import
time
import
time
start
=
time
.
time
()
start
=
time
.
time
()
if
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
:
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
ext
=
None
ext
=
None
for
e
in
self
.
supported_exts
:
for
e
in
self
.
supported_exts
:
...
@@ -478,18 +476,14 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -478,18 +476,14 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_path
=
path
,
fasta_path
=
path
,
alignment_dir
=
self
.
alignment_dir
alignment_dir
=
self
.
alignment_dir
)
)
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"process_fasta in data_modules"
)
if
self
.
_output_raw
:
if
self
.
_output_raw
:
return
data
return
data
process_start
=
time
.
time
()
# process all_chain_features
# process all_chain_features
data
=
self
.
feature_pipeline
.
process_features
(
data
,
data
=
self
.
feature_pipeline
.
process_features
(
data
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
is_multimer
=
True
)
is_multimer
=
True
)
end
=
time
.
time
()
calculate_elapse
(
process_start
,
end
,
"process_features in data_modules"
)
calculate_elapse
(
start
,
end
,
"dataset get_item in data_modules"
)
# if it's inference mode, only need all_chain_features
# if it's inference mode, only need all_chain_features
data
[
"batch_idx"
]
=
torch
.
tensor
(
data
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
data
[
"aatype"
].
shape
[
-
1
])],
[
idx
for
_
in
range
(
data
[
"aatype"
].
shape
[
-
1
])],
...
...
openfold/data/data_pipeline.py
View file @
e72e4e62
...
@@ -28,16 +28,14 @@ import pickle
...
@@ -28,16 +28,14 @@ import pickle
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
feature_processing_multimer
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
feature_processing_multimer
from
openfold.data.templates
import
get_custom_template_features
,
empty_template_feats
from
openfold.data.templates
import
get_custom_template_features
,
empty_template_feats
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools.utils
import
to_date
,
NonDaemonicProcess
,
NonDaemonicProcessPool
from
openfold.data.tools.utils
import
to_date
from
openfold.np
import
residue_constants
,
protein
from
openfold.np
import
residue_constants
,
protein
import
concurrent
from
concurrent.futures
import
ThreadPoolExecutor
FeatureDict
=
MutableMapping
[
str
,
np
.
ndarray
]
FeatureDict
=
MutableMapping
[
str
,
np
.
ndarray
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
def calculate_elapse(start, end, name):
    """Print how long a named step took and return the elapsed seconds.

    Debug/profiling helper: reports the wall-clock duration between two
    ``time.time()`` timestamps, rounded for readability.

    Args:
        start: Start timestamp in seconds (e.g. from ``time.time()``).
        end: End timestamp in seconds.
        name: Human-readable label for the step being timed.

    Returns:
        The elapsed time in seconds (``end - start``).
    """
    elapse = end - start
    # Round to millisecond precision; minutes are derived for long-running steps.
    print(f"{name} runs {round(elapse, 3)} seconds i.e. {round(elapse / 60, 3)} minutes")
    # Return the computed value so callers can use it (previously discarded).
    return elapse
def
make_template_features
(
def
make_template_features
(
input_sequence
:
str
,
input_sequence
:
str
,
...
@@ -738,14 +736,11 @@ class DataPipeline:
...
@@ -738,14 +736,11 @@ class DataPipeline:
fp
.
close
()
fp
.
close
()
else
:
else
:
# Now will split the following steps into multiple processes
# Now will split the following steps into multiple processes
import
time
current_directory
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
current_directory
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
cmd
=
f
"
{
current_directory
}
/parse_msa_files.py"
cmd
=
f
"
{
current_directory
}
/tools/parse_msa_files.py"
start
=
time
.
time
()
msa_data
=
subprocess
.
run
([
'python'
,
cmd
,
f
"--alignment_dir=
{
alignment_dir
}
"
],
capture_output
=
True
,
text
=
True
)
msa_data
=
subprocess
.
run
([
'python'
,
cmd
,
f
"--alignment_dir=
{
alignment_dir
}
"
],
capture_output
=
True
,
text
=
True
)
msa_data
=
pickle
.
load
((
open
(
msa_data
.
stdout
.
lstrip
().
rstrip
(),
'rb'
)))
msa_data
=
pickle
.
load
((
open
(
msa_data
.
stdout
.
lstrip
().
rstrip
(),
'rb'
)))
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"parse_msa_files in data_pipeline"
)
return
msa_data
return
msa_data
def
_parse_template_hit_files
(
def
_parse_template_hit_files
(
...
@@ -820,19 +815,14 @@ class DataPipeline:
...
@@ -820,19 +815,14 @@ class DataPipeline:
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
import
time
start_main
=
time
.
time
()
start
=
time
.
time
()
msas
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
input_sequence
,
alignment_index
alignment_dir
,
input_sequence
,
alignment_index
)
)
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"get_msas in data_pipeline"
)
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
msas
=
msas
msas
=
msas
)
)
end_main
=
time
.
time
()
calculate_elapse
(
start_main
,
end_main
,
"process_msa_feats in data_pipeline"
)
return
msa_features
return
msa_features
# Load and process sequence embedding features
# Load and process sequence embedding features
...
@@ -1216,8 +1206,10 @@ class DataPipelineMultimer:
...
@@ -1216,8 +1206,10 @@ class DataPipelineMultimer:
with
open
(
fasta_path
)
as
f
:
with
open
(
fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
all_chain_features
=
{}
all_chain_features
=
{}
sequence_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
...
@@ -1228,6 +1220,7 @@ class DataPipelineMultimer:
...
@@ -1228,6 +1220,7 @@ class DataPipelineMultimer:
)
)
continue
continue
chain_features
=
self
.
_process_single_chain
(
chain_features
=
self
.
_process_single_chain
(
chain_id
=
desc
,
chain_id
=
desc
,
sequence
=
seq
,
sequence
=
seq
,
...
@@ -1236,6 +1229,7 @@ class DataPipelineMultimer:
...
@@ -1236,6 +1229,7 @@ class DataPipelineMultimer:
is_homomer_or_monomer
=
is_homomer_or_monomer
is_homomer_or_monomer
=
is_homomer_or_monomer
)
)
chain_features
=
convert_monomer_features
(
chain_features
=
convert_monomer_features
(
chain_features
,
chain_features
,
chain_id
=
desc
chain_id
=
desc
...
@@ -1243,12 +1237,15 @@ class DataPipelineMultimer:
...
@@ -1243,12 +1237,15 @@ class DataPipelineMultimer:
all_chain_features
[
desc
]
=
chain_features
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing_multimer
.
pair_and_merge
(
np_example
=
feature_processing_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
all_chain_features
=
all_chain_features
,
)
)
# Pad MSA to avoid zero-sized extra_msa.
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
np_example
=
pad_msa
(
np_example
,
512
)
...
@@ -1284,18 +1281,21 @@ class DataPipelineMultimer:
...
@@ -1284,18 +1281,21 @@ class DataPipelineMultimer:
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
all_chain_features
=
{}
all_chain_features
=
{}
sequence_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
list
(
mmcif
.
chain_to_seqres
.
values
())))
==
1
is_homomer_or_monomer
=
len
(
set
(
list
(
mmcif
.
chain_to_seqres
.
values
())))
==
1
for
chain_id
,
seq
in
mmcif
.
chain_to_seqres
.
items
():
for
chain_id
,
seq
in
mmcif
.
chain_to_seqres
.
items
():
desc
=
"_"
.
join
([
mmcif
.
file_id
,
chain_id
])
desc
=
"_"
.
join
([
mmcif
.
file_id
,
chain_id
])
if
seq
in
sequence_features
:
if
seq
in
sequence_features
:
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
sequence_features
[
seq
]
sequence_features
[
seq
]
)
)
continue
continue
chain_features
=
self
.
_process_single_chain
(
chain_features
=
self
.
_process_single_chain
(
chain_id
=
desc
,
chain_id
=
desc
,
sequence
=
seq
,
sequence
=
seq
,
...
@@ -1304,23 +1304,29 @@ class DataPipelineMultimer:
...
@@ -1304,23 +1304,29 @@ class DataPipelineMultimer:
is_homomer_or_monomer
=
is_homomer_or_monomer
is_homomer_or_monomer
=
is_homomer_or_monomer
)
)
chain_features
=
convert_monomer_features
(
chain_features
=
convert_monomer_features
(
chain_features
,
chain_features
,
chain_id
=
desc
chain_id
=
desc
)
)
mmcif_feats
=
self
.
get_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
self
.
get_mmcif_features
(
mmcif
,
chain_id
)
chain_features
.
update
(
mmcif_feats
)
chain_features
.
update
(
mmcif_feats
)
all_chain_features
[
desc
]
=
chain_features
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing_multimer
.
pair_and_merge
(
np_example
=
feature_processing_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
all_chain_features
=
all_chain_features
,
)
)
# Pad MSA to avoid zero-sized extra_msa.
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
np_example
=
pad_msa
(
np_example
,
512
)
return
np_example
return
np_example
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment