Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4208a761
"lib/vscode:/vscode.git/clone" did not exist on "f465aca39c8e865f7ee13194bd858113dc133566"
Unverified
Commit
4208a761
authored
Sep 15, 2023
by
Christina Floristean
Committed by
GitHub
Sep 15, 2023
Browse files
Merge pull request #346 from dingquanyu/permutation
Update validation
parents
61cbc7d3
bc49758a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
22 deletions
+30
-22
openfold/data/data_modules.py
openfold/data/data_modules.py
+26
-22
train_openfold.py
train_openfold.py
+4
-0
No files found.
openfold/data/data_modules.py
View file @
4208a761
...
...
@@ -23,6 +23,7 @@ import tempfile
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
import
random
@
contextlib
.
contextmanager
...
...
@@ -368,15 +369,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
"""
super
(
OpenFoldSingleMultimerDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
mmcif_data_cache_path
=
mmcif_data_cache_path
self
.
chain_data_cache
=
None
if
chain_data_cache_path
is
not
None
:
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
if
mmcif_data_cache_path
is
not
None
:
with
open
(
mmcif_data_cache_path
,
"r"
)
as
infile
:
if
self
.
mmcif_data_cache_path
is
not
None
:
with
open
(
self
.
mmcif_data_cache_path
,
"r"
)
as
infile
:
self
.
mmcif_data_cache
=
json
.
load
(
infile
)
assert
isinstance
(
self
.
mmcif_data_cache
,
dict
)
...
...
@@ -413,9 +414,12 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
if
self
.
mmcif_data_cache
is
not
None
:
if
self
.
mmcif_data_cache_path
is
not
None
:
self
.
_mmcifs
=
list
(
self
.
mmcif_data_cache
.
keys
())
elif
self
.
mmcif_data_cache_path
is
None
and
self
.
alignment_dir
is
not
None
:
self
.
_mmcifs
=
[
i
.
split
(
"_"
)[
0
]
for
i
in
os
.
listdir
(
self
.
alignment_dir
)]
else
:
raise
ValueError
(
"You must provide at least one of the mmcif_data_cache or alignment_dir"
)
self
.
_mmcif_id_to_idx_dict
=
{
mmcif
:
i
for
i
,
mmcif
in
enumerate
(
self
.
_mmcifs
)
}
...
...
@@ -470,9 +474,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
chains
=
self
.
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
print
(
f
"mmcif_id is :
{
mmcif_id
}
idx:
{
idx
}
and has
{
len
(
chains
)
}
chains"
)
alignment_index
=
None
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
...
...
@@ -715,17 +716,18 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def
filter_samples
(
self
,
dataset_idx
):
dataset
=
self
.
datasets
[
dataset_idx
]
mmcif_data_cache
=
dataset
.
mmcif_data_cache
mmcif_data_cache
=
dataset
.
mmcif_data_cache
if
hasattr
(
dataset
,
"mmcif_data_cache"
)
else
None
selected_idx
=
[]
if
mmcif_data_cache
is
not
None
:
for
i
in
range
(
len
(
mmcif_data_cache
)):
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
i
)
chains
=
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
,
minimum_number_of_residues
=
5
):
selected_idx
.
append
(
i
)
else
:
selected_idx
=
list
(
range
(
len
(
dataset
.
_mmcif_id_to_idx_dict
)))
return
selected_idx
def
__getitem__
(
self
,
idx
):
...
...
@@ -746,6 +748,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
selected_idx
=
self
.
filter_samples
(
dataset_idx
)
random
.
shuffle
(
selected_idx
)
if
len
(
selected_idx
)
<
self
.
epoch_len
:
self
.
epoch_len
=
len
(
selected_idx
)
logging
.
info
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
...
...
@@ -849,7 +852,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
self
.
stage
=
stage
self
.
generator
=
generator
print
(
'initialised a multimer dataloader'
)
def
__iter__
(
self
):
it
=
super
().
__iter__
()
...
...
@@ -1092,6 +1094,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
template_mmcif_dir
:
str
,
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
val_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataModule
,
self
).
__init__
(
config
,
template_mmcif_dir
,
...
...
@@ -1099,6 +1102,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_data_dir
,
**
kwargs
)
self
.
train_mmcif_data_cache_path
=
train_mmcif_data_cache_path
self
.
training_mode
=
self
.
train_data_dir
is
not
None
self
.
val_mmcif_data_cache_path
=
val_mmcif_data_cache_path
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
...
...
@@ -1167,6 +1171,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self
.
eval_dataset
=
dataset_gen
(
data_dir
=
self
.
val_data_dir
,
alignment_dir
=
self
.
val_alignment_dir
,
mmcif_data_cache_path
=
self
.
val_mmcif_data_cache_path
,
filter_path
=
None
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
...
...
@@ -1206,7 +1211,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
batch_size
=
1
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
)
print
(
f
"generated training dataloader"
)
return
dl
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
train_openfold.py
View file @
4208a761
...
...
@@ -509,6 +509,10 @@ if __name__ == "__main__":
"--val_alignment_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing precomputed validation alignments"
)
parser
.
add_argument
(
"--val_mmcif_data_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"path to the json file which records all the information of mmcif structures used during validation"
)
parser
.
add_argument
(
"--kalign_binary_path"
,
type
=
str
,
default
=
'/usr/bin/kalign'
,
help
=
"Path to the kalign binary"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment