Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4208a761
Unverified
Commit
4208a761
authored
Sep 15, 2023
by
Christina Floristean
Committed by
GitHub
Sep 15, 2023
Browse files
Merge pull request #346 from dingquanyu/permutation
Update validation
parents
61cbc7d3
bc49758a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
22 deletions
+30
-22
openfold/data/data_modules.py
openfold/data/data_modules.py
+26
-22
train_openfold.py
train_openfold.py
+4
-0
No files found.
openfold/data/data_modules.py
View file @
4208a761
...
...
@@ -23,6 +23,7 @@ import tempfile
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
import
random
@
contextlib
.
contextmanager
...
...
@@ -368,15 +369,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
"""
super
(
OpenFoldSingleMultimerDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
mmcif_data_cache_path
=
mmcif_data_cache_path
self
.
chain_data_cache
=
None
if
chain_data_cache_path
is
not
None
:
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
if
mmcif_data_cache_path
is
not
None
:
with
open
(
mmcif_data_cache_path
,
"r"
)
as
infile
:
if
self
.
mmcif_data_cache_path
is
not
None
:
with
open
(
self
.
mmcif_data_cache_path
,
"r"
)
as
infile
:
self
.
mmcif_data_cache
=
json
.
load
(
infile
)
assert
isinstance
(
self
.
mmcif_data_cache
,
dict
)
...
...
@@ -413,13 +414,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
if
self
.
mmcif_data_cache
is
not
None
:
if
self
.
mmcif_data_cache_path
is
not
None
:
self
.
_mmcifs
=
list
(
self
.
mmcif_data_cache
.
keys
())
self
.
_mmcif_id_to_idx_dict
=
{
elif
self
.
mmcif_data_cache_path
is
None
and
self
.
alignment_dir
is
not
None
:
self
.
_mmcifs
=
[
i
.
split
(
"_"
)[
0
]
for
i
in
os
.
listdir
(
self
.
alignment_dir
)]
else
:
raise
ValueError
(
"You must provide at least one of the mmcif_data_cache or alignment_dir"
)
self
.
_mmcif_id_to_idx_dict
=
{
mmcif
:
i
for
i
,
mmcif
in
enumerate
(
self
.
_mmcifs
)
}
# changed template_featurizer to hmmsearch for now just to run the test
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
mmcif_dir
=
template_mmcif_dir
,
...
...
@@ -470,9 +474,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
chains
=
self
.
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
print
(
f
"mmcif_id is :
{
mmcif_id
}
idx:
{
idx
}
and has
{
len
(
chains
)
}
chains"
)
alignment_index
=
None
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
...
...
@@ -715,17 +716,18 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def
filter_samples
(
self
,
dataset_idx
):
dataset
=
self
.
datasets
[
dataset_idx
]
mmcif_data_cache
=
dataset
.
mmcif_data_cache
mmcif_data_cache
=
dataset
.
mmcif_data_cache
if
hasattr
(
dataset
,
"mmcif_data_cache"
)
else
None
selected_idx
=
[]
for
i
in
range
(
len
(
mmcif_data_cache
)):
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
i
)
chains
=
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
,
minimum_number_of_residues
=
5
):
selected_idx
.
append
(
i
)
if
mmcif_data_cache
is
not
None
:
for
i
in
range
(
len
(
mmcif_data_cache
)):
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
i
)
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
,
minimum_number_of_residues
=
5
):
selected_idx
.
append
(
i
)
else
:
selected_idx
=
list
(
range
(
len
(
dataset
.
_mmcif_id_to_idx_dict
)))
return
selected_idx
def
__getitem__
(
self
,
idx
):
...
...
@@ -746,6 +748,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
selected_idx
=
self
.
filter_samples
(
dataset_idx
)
random
.
shuffle
(
selected_idx
)
if
len
(
selected_idx
)
<
self
.
epoch_len
:
self
.
epoch_len
=
len
(
selected_idx
)
logging
.
info
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
...
...
@@ -849,7 +852,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
self
.
stage
=
stage
self
.
generator
=
generator
print
(
'initialised a multimer dataloader'
)
def
__iter__
(
self
):
it
=
super
().
__iter__
()
...
...
@@ -1092,6 +1094,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
template_mmcif_dir
:
str
,
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
val_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataModule
,
self
).
__init__
(
config
,
template_mmcif_dir
,
...
...
@@ -1099,6 +1102,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_data_dir
,
**
kwargs
)
self
.
train_mmcif_data_cache_path
=
train_mmcif_data_cache_path
self
.
training_mode
=
self
.
train_data_dir
is
not
None
self
.
val_mmcif_data_cache_path
=
val_mmcif_data_cache_path
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
...
...
@@ -1167,6 +1171,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self
.
eval_dataset
=
dataset_gen
(
data_dir
=
self
.
val_data_dir
,
alignment_dir
=
self
.
val_alignment_dir
,
mmcif_data_cache_path
=
self
.
val_mmcif_data_cache_path
,
filter_path
=
None
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
...
...
@@ -1206,7 +1211,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
batch_size
=
1
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
)
print
(
f
"generated training dataloader"
)
return
dl
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
train_openfold.py
View file @
4208a761
...
...
@@ -509,6 +509,10 @@ if __name__ == "__main__":
"--val_alignment_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing precomputed validation alignments"
)
parser
.
add_argument
(
"--val_mmcif_data_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"path to the json file which records all the information of mmcif structures used during validation"
)
parser
.
add_argument
(
"--kalign_binary_path"
,
type
=
str
,
default
=
'/usr/bin/kalign'
,
help
=
"Path to the kalign binary"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment