Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b55ad675
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "099769d2ecfd01a8baa8d950030df454a042c910"
Commit
b55ad675
authored
Jul 10, 2023
by
Geoffrey Yu
Browse files
finished constructing OpenFoldMultimerDataset filtering and sampling steps
parent
71fdc063
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
56 deletions
+17
-56
openfold/data/data_modules.py
openfold/data/data_modules.py
+17
-56
No files found.
openfold/data/data_modules.py
View file @
b55ad675
...
@@ -212,6 +212,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -212,6 +212,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
name
=
self
.
idx_to_chain_id
(
idx
)
name
=
self
.
idx_to_chain_id
(
idx
)
print
(
f
"name is
{
name
}
"
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_index
=
None
alignment_index
=
None
...
@@ -371,7 +372,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -371,7 +372,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
if
mmcif_data_cache_path
is
not
None
:
if
mmcif_data_cache_path
is
not
None
:
print
(
f
"mmcif_data_cache_path is
{
mmcif_data_cache_path
}
"
)
with
open
(
mmcif_data_cache_path
,
"r"
)
as
infile
:
with
open
(
mmcif_data_cache_path
,
"r"
)
as
infile
:
self
.
mmcif_data_cache
=
json
.
load
(
infile
)
self
.
mmcif_data_cache
=
json
.
load
(
infile
)
assert
isinstance
(
self
.
mmcif_data_cache
,
dict
)
assert
isinstance
(
self
.
mmcif_data_cache
,
dict
)
...
@@ -678,7 +678,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -678,7 +678,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
idx
=
[]
idx
=
[]
for
_
in
range
(
max_cache_len
):
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
candidate_idx
=
next
(
idx_iter
)
## TODO: add filtering criteria for multimer
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
chain_data_cache_entry
=
chain_data_cache
[
chain_id
]
chain_data_cache_entry
=
chain_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
chain_data_cache_entry
)):
if
(
not
deterministic_train_filter
(
chain_data_cache_entry
)):
...
@@ -703,7 +702,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -703,7 +702,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
yield
datapoint_idx
yield
datapoint_idx
self
.
_samples
=
[
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
self
.
_samples
=
[
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
if
(
_roll_at_init
):
if
(
_roll_at_init
):
self
.
reroll
()
self
.
reroll
()
...
@@ -721,13 +719,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -721,13 +719,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
replacement
=
True
,
replacement
=
True
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
self
.
datapoints
=
[]
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
for
dataset_idx
in
dataset_choices
:
samples
=
self
.
_samples
[
dataset_idx
]
samples
=
self
.
_samples
[
dataset_idx
]
datapoint_idx
=
next
(
samples
)
datapoint_idx
=
next
(
samples
)
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
print
(
f
"datapoints is
{
self
.
datapoints
}
"
)
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
...
@@ -747,59 +743,23 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...
@@ -747,59 +743,23 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self
.
probabilities
=
probabilities
self
.
probabilities
=
probabilities
self
.
epoch_len
=
epoch_len
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
generator
=
generator
self
.
_samples
=
[
self
.
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
if
_roll_at_init
:
if
_roll_at_init
:
self
.
reroll
()
self
.
reroll
()
def filter_samples(self, dataset_idx):
    """Return the dataset-local indices of mmCIF entries that pass the
    multimer training filters.

    An entry is selected when it has more than one chain and
    ``deterministic_multimer_train_filter(entry, max_resolution=9)`` is
    falsy.

    NOTE(review): selecting entries for which the deterministic filter
    returns False looks inverted relative to the single-dataset path,
    which *skips* entries failing ``deterministic_train_filter`` — confirm
    the return convention of ``deterministic_multimer_train_filter``
    before relying on this.

    Args:
        dataset_idx: Index into ``self.datasets``.

    Returns:
        list[int]: indices ``i`` (valid for ``dataset.idx_to_mmcif_id``)
        of the selected entries.
    """
    dataset = self.datasets[dataset_idx]
    mmcif_data_cache = dataset.mmcif_data_cache
    selected_idx = []
    for i in range(len(mmcif_data_cache)):
        mmcif_id = dataset.idx_to_mmcif_id(i)
        # Single lookup; the entry also provides the chain list.
        mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
        chains = mmcif_data_cache_entry['chain_ids']
        # Multimer targets only (> 1 chain). The resolution cutoff (9 A)
        # is hard-coded here rather than taken from config.
        if (len(chains) > 1) and (not deterministic_multimer_train_filter(
            mmcif_data_cache_entry,
            max_resolution=9,
        )):
            selected_idx.append(i)
    return selected_idx
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
dataset_idx
,
datapoint_idx
=
self
.
datapoints
[
idx
]
dataset_idx
,
datapoint_idx
=
self
.
datapoints
[
idx
]
...
@@ -811,20 +771,21 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...
@@ -811,20 +771,21 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def reroll(self):
    """Draw a new epoch of datapoints.

    Samples dataset indices according to ``self.probabilities``, runs
    ``filter_samples`` on each chosen dataset, shrinks ``self.epoch_len``
    when a dataset cannot supply enough filtered entries, and rebuilds
    ``self.datapoints`` as ``(dataset_idx, datapoint_idx)`` pairs.
    """
    dataset_choices = torch.multinomial(
        torch.tensor(self.probabilities),
        num_samples=len(self.probabilities),
        replacement=True,
        generator=self.generator,
    )
    self.datapoints = []
    for dataset_idx in dataset_choices:
        selected_idx = self.filter_samples(dataset_idx)
        # Clamp the epoch length to what survived filtering.
        if len(selected_idx) < self.epoch_len:
            self.epoch_len = len(selected_idx)
        # Accumulate across datasets — the previous version rebound
        # self.datapoints on every iteration, so only the last chosen
        # dataset's pairs survived — and record the *filtered* indices
        # rather than the raw positions 0..epoch_len-1, which ignored
        # the filtering entirely.
        self.datapoints.extend(
            (dataset_idx, datapoint_idx)
            for datapoint_idx in selected_idx[: self.epoch_len]
        )
class
OpenFoldBatchCollator
:
class
OpenFoldBatchCollator
:
def
__call__
(
self
,
prots
):
def
__call__
(
self
,
prots
):
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment