Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
aec12764
Commit
aec12764
authored
Nov 30, 2023
by
Dingquan Yu
Browse files
added timing steps
parent
cdeb8d1b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
12 deletions
+38
-12
openfold/data/data_modules.py
openfold/data/data_modules.py
+10
-7
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+21
-1
openfold/data/parsers.py
openfold/data/parsers.py
+3
-3
train_openfold.py
train_openfold.py
+4
-1
No files found.
openfold/data/data_modules.py
View file @
aec12764
...
@@ -22,6 +22,9 @@ from openfold.utils.tensor_utils import (
...
@@ -22,6 +22,9 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
tensor_tree_map
,
)
)
def calculate_elapse(start, end, name="this function"):
    """Print the wall-clock time elapsed between two timestamps.

    Args:
        start: Timestamp, in seconds, taken before the timed section.
        end: Timestamp, in seconds, taken after the timed section.
        name: Label identifying the timed section in the printed message.
            Defaults to "this function" so that existing two-argument
            calls keep printing exactly the same text as before.

    Returns:
        None. The report is written to standard output.
    """
    elapse = end - start
    # Report both seconds and minutes, each rounded to three decimals.
    seconds = round(elapse, 3)
    minutes = round(elapse / 60, 3)
    print(f"{name} runs {seconds} seconds i.e. {minutes} minutes")
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -195,7 +198,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -195,7 +198,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_index
=
alignment_index
,
alignment_index
=
alignment_index
,
seqemb_mode
=
self
.
config
.
seqemb_mode
.
enabled
seqemb_mode
=
self
.
config
.
seqemb_mode
.
enabled
)
)
return
data
return
data
def
chain_id_to_idx
(
self
,
chain_id
):
def
chain_id_to_idx
(
self
,
chain_id
):
...
@@ -423,24 +425,25 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -423,24 +425,25 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
with
open
(
path
,
'r'
)
as
f
:
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
mmcif_string
=
f
.
read
()
import
time
mmcif_object
=
mmcif_parsing
.
parse
(
mmcif_object
=
mmcif_parsing
.
parse
(
file_id
=
file_id
,
mmcif_string
=
mmcif_string
file_id
=
file_id
,
mmcif_string
=
mmcif_string
)
)
# Crash if an error is encountered. Any parsing errors should have
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
# been dealt with at the alignment stage.
if
mmcif_object
.
mmcif_object
is
None
:
if
mmcif_object
.
mmcif_object
is
None
:
raise
list
(
mmcif_object
.
errors
.
values
())[
0
]
raise
list
(
mmcif_object
.
errors
.
values
())[
0
]
mmcif_object
=
mmcif_object
.
mmcif_object
mmcif_object
=
mmcif_object
.
mmcif_object
# print(f" ###### line 442 started mmcif_processing")
# start = time.time()
data
=
self
.
data_pipeline
.
process_mmcif
(
data
=
self
.
data_pipeline
.
process_mmcif
(
mmcif
=
mmcif_object
,
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
alignment_index
=
alignment_index
alignment_index
=
alignment_index
)
)
# end = time.time()
# calculate_elapse(start, end)
return
data
return
data
def
mmcif_id_to_idx
(
self
,
mmcif_id
):
def
mmcif_id_to_idx
(
self
,
mmcif_id
):
...
@@ -450,6 +453,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -450,6 +453,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
return
self
.
_mmcifs
[
idx
]
return
self
.
_mmcifs
[
idx
]
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
print
(
f
"####### line 456 idx is
{
idx
}
"
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
alignment_index
=
None
alignment_index
=
None
...
@@ -741,9 +746,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
...
@@ -741,9 +746,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
samples
=
samples
.
squeeze
()
samples
=
samples
.
squeeze
()
cache
=
[
i
for
i
,
s
in
zip
(
idx
,
samples
)
if
s
]
cache
=
[
i
for
i
,
s
in
zip
(
idx
,
samples
)
if
s
]
for
datapoint_idx
in
cache
:
for
datapoint_idx
in
cache
:
yield
datapoint_idx
yield
datapoint_idx
...
...
openfold/data/data_pipeline.py
View file @
aec12764
...
@@ -35,6 +35,9 @@ from openfold.np import residue_constants, protein
...
@@ -35,6 +35,9 @@ from openfold.np import residue_constants, protein
FeatureDict
=
MutableMapping
[
str
,
np
.
ndarray
]
FeatureDict
=
MutableMapping
[
str
,
np
.
ndarray
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
def calculate_elapse(start, end, name="this function"):
    """Print how long a named step took, in seconds and in minutes.

    Args:
        start: Timestamp, in seconds, taken before the timed step.
        end: Timestamp, in seconds, taken after the timed step.
        name: Label for the timed step, placed at the start of the
            message. It is optional and defaults to "this function", so
            the helper can also be called with only the two timestamps.

    Returns:
        None. The report is written to standard output.
    """
    elapse = end - start
    # Both figures are rounded to three decimals for readability.
    seconds = round(elapse, 3)
    minutes = round(elapse / 60, 3)
    print(f"{name} runs {seconds} seconds i.e. {minutes} minutes")
def
make_template_features
(
def
make_template_features
(
input_sequence
:
str
,
input_sequence
:
str
,
...
@@ -739,13 +742,21 @@ class DataPipeline:
...
@@ -739,13 +742,21 @@ class DataPipeline:
filename
,
ext
=
os
.
path
.
splitext
(
f
)
filename
,
ext
=
os
.
path
.
splitext
(
f
)
if
(
ext
==
".a3m"
):
if
(
ext
==
".a3m"
):
import
time
start
=
time
.
time
()
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"parser.parse_a3m"
)
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
import
time
start
=
time
.
time
()
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
=
parsers
.
parse_stockholm
(
msa
=
parsers
.
parse_stockholm
(
fp
.
read
()
fp
.
read
()
)
)
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"parsers.parse_stockholm"
)
else
:
else
:
continue
continue
...
@@ -825,13 +836,22 @@ class DataPipeline:
...
@@ -825,13 +836,22 @@ class DataPipeline:
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
import
time
start_main
=
time
.
time
()
start
=
time
.
time
()
msas
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
input_sequence
,
alignment_index
alignment_dir
,
input_sequence
,
alignment_index
)
)
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"get_msas"
)
start
=
time
.
time
()
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
msas
=
msas
msas
=
msas
)
)
end
=
time
.
time
()
calculate_elapse
(
start
,
end
,
"make_msa_features"
)
end_main
=
time
.
time
()
calculate_elapse
(
start_main
,
end_main
,
"process_msa_feats"
)
return
msa_features
return
msa_features
# Load and process sequence embedding features
# Load and process sequence embedding features
...
...
openfold/data/parsers.py
View file @
aec12764
...
@@ -20,7 +20,7 @@ import itertools
...
@@ -20,7 +20,7 @@ import itertools
import
re
import
re
import
string
import
string
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Set
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Set
import
asyncio
DeletionMatrix
=
Sequence
[
Sequence
[
int
]]
DeletionMatrix
=
Sequence
[
Sequence
[
int
]]
...
@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa:
...
@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa:
line
=
line
.
strip
()
line
=
line
.
strip
()
if
not
line
or
line
.
startswith
((
"#"
,
"//"
)):
if
not
line
or
line
.
startswith
((
"#"
,
"//"
)):
continue
continue
name
,
sequence
=
line
.
split
()
name
,
sequence
=
line
.
split
(
maxsplit
=
1
)
if
name
not
in
name_to_sequence
:
if
name
not
in
name_to_sequence
:
name_to_sequence
[
name
]
=
""
name_to_sequence
.
setdefault
(
name
,
""
)
name_to_sequence
[
name
]
+=
sequence
name_to_sequence
[
name
]
+=
sequence
msa
=
[]
msa
=
[]
...
...
train_openfold.py
View file @
aec12764
...
@@ -42,6 +42,9 @@ from scripts.zero_to_fp32 import (
...
@@ -42,6 +42,9 @@ from scripts.zero_to_fp32 import (
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
def calculate_elapse(start, end, name="this function"):
    """Print the time elapsed between two timestamps.

    Args:
        start: Timestamp, in seconds, taken before the timed section.
        end: Timestamp, in seconds, taken after the timed section.
        name: Label identifying what was timed. Defaults to
            "this function", which reproduces the original message for
            callers that pass only the two timestamps.

    Returns:
        None. The report is written to standard output.
    """
    elapse = end - start
    # Seconds and minutes are each rounded to three decimals.
    seconds = round(elapse, 3)
    minutes = round(elapse / 60, 3)
    print(f"{name} runs {seconds} seconds i.e. {minutes} minutes")
class
OpenFoldWrapper
(
pl
.
LightningModule
):
class
OpenFoldWrapper
(
pl
.
LightningModule
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -316,7 +319,7 @@ def main(args):
...
@@ -316,7 +319,7 @@ def main(args):
batch_seed
=
args
.
seed
,
batch_seed
=
args
.
seed
,
**
vars
(
args
)
**
vars
(
args
)
)
)
data_module
.
prepare_data
()
data_module
.
prepare_data
()
data_module
.
setup
()
data_module
.
setup
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment