Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b61e99bc
Commit
b61e99bc
authored
Jul 09, 2023
by
Geoffrey Yu
Browse files
update multimer datasets
parent
33d8de81
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
3 deletions
+28
-3
openfold/data/data_modules.py
openfold/data/data_modules.py
+28
-3
No files found.
openfold/data/data_modules.py
View file @
b61e99bc
...
...
@@ -371,6 +371,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
if
mmcif_data_cache_path
is
not
None
:
print
(
f
"mmcif_data_cache_path is
{
mmcif_data_cache_path
}
"
)
with
open
(
mmcif_data_cache_path
,
"r"
)
as
infile
:
self
.
mmcif_data_cache
=
json
.
load
(
infile
)
assert
isinstance
(
self
.
mmcif_data_cache
,
dict
)
...
...
@@ -747,6 +748,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
_samples
=
[
self
.
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
if
_roll_at_init
:
self
.
reroll
()
def
looped_shuffled_dataset_idx
(
self
,
dataset_len
):
while
True
:
...
...
@@ -774,9 +777,10 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
candidate_idx
=
next
(
idx_iter
)
## TO DO: add filtering criteria for multimer
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
candidate_idx
)
chains
=
chain_data_cache
[
mmcif_id
][
'chain_ids'
]
print
(
f
"mmcif_id is
{
mmcif_id
}
and candidate_idx:
{
candidate_idx
}
"
)
chains
=
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
(
not
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
)):
if
(
not
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
)):
continue
p
=
get_stochastic_train_filter_prob
(
...
...
@@ -797,6 +801,27 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
for
datapoint_idx
in
cache
:
yield
datapoint_idx
def
__getitem__
(
self
,
idx
):
dataset_idx
,
datapoint_idx
=
self
.
datapoints
[
idx
]
return
self
.
datasets
[
dataset_idx
][
datapoint_idx
]
def
__len__
(
self
):
return
self
.
epoch_len
def
reroll
(
self
):
dataset_choices
=
torch
.
multinomial
(
torch
.
tensor
(
self
.
probabilities
),
num_samples
=
self
.
epoch_len
,
replacement
=
True
,
generator
=
self
.
generator
,
)
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
samples
=
self
.
_samples
[
dataset_idx
]
datapoint_idx
=
next
(
samples
)
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
print
(
f
"datapoints is
{
self
.
datapoints
}
"
)
...
...
@@ -1193,7 +1218,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
generator
=
generator
,
_roll_at_init
=
Fals
e
,
_roll_at_init
=
Tru
e
,
)
if
(
self
.
val_data_dir
is
not
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment