Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
8470b803
Commit
8470b803
authored
Oct 16, 2023
by
Christina Floristean
Browse files
Fix multimer sampling
parent
14853379
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
9 deletions
+9
-9
openfold/data/data_modules.py
openfold/data/data_modules.py
+9
-9
No files found.
openfold/data/data_modules.py
View file @
8470b803
...
@@ -694,14 +694,14 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
...
@@ -694,14 +694,14 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
@
staticmethod
@
staticmethod
def
get_stochastic_train_filter_prob
(
def
get_stochastic_train_filter_prob
(
cache_entry
:
Any
,
cache_entry
:
Any
,
)
->
floa
t
:
)
->
lis
t
:
# Stochastic filters
# Stochastic filters
cluster_sizes
=
cache_entry
.
get
(
"cluster_sizes"
,
[])
cluster_sizes
=
cache_entry
.
get
(
"cluster_sizes"
)
chain_probs
=
[
1
/
c
for
c
in
cluster_sizes
if
c
>
0
]
if
cluster_sizes
is
not
None
:
if
chain_probs
:
return
[
1
/
c
if
c
>
0
else
1
for
c
in
cluster_sizes
]
return
sum
(
chain_probs
)
return
1.
num_chains
=
len
(
cache_entry
[
"chain_ids"
])
return
[
1.
]
*
num_chains
def
looped_samples
(
self
,
dataset_idx
):
def
looped_samples
(
self
,
dataset_idx
):
max_cache_len
=
int
(
self
.
epoch_len
*
self
.
probabilities
[
dataset_idx
])
max_cache_len
=
int
(
self
.
epoch_len
*
self
.
probabilities
[
dataset_idx
])
...
@@ -718,11 +718,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
...
@@ -718,11 +718,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
if
not
self
.
deterministic_train_filter
(
mmcif_data_cache_entry
):
if
not
self
.
deterministic_train_filter
(
mmcif_data_cache_entry
):
continue
continue
p
=
self
.
get_stochastic_train_filter_prob
(
chain_probs
=
self
.
get_stochastic_train_filter_prob
(
mmcif_data_cache_entry
,
mmcif_data_cache_entry
,
)
)
weights
.
app
end
([
1.
-
p
,
p
])
weights
.
ext
end
([
[
1.
-
p
,
p
]
for
p
in
chain_probs
]
)
idx
.
app
end
(
candidate_idx
)
idx
.
ext
end
(
[
candidate_idx
]
*
len
(
chain_probs
)
)
samples
=
torch
.
multinomial
(
samples
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
torch
.
tensor
(
weights
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment