Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
676b6668
"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "58bca13f1b1ef43fdabdcff88438a6c0e4aa31d8"
Commit
676b6668
authored
Feb 01, 2022
by
Gustaf Ahdritz
Browse files
Make filtering more efficient
parent
fb341b17
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
52 deletions
+55
-52
openfold/data/data_modules.py
openfold/data/data_modules.py
+52
-47
train_openfold.py
train_openfold.py
+3
-5
No files found.
openfold/data/data_modules.py
View file @
676b6668
...
...
@@ -212,9 +212,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return
len
(
self
.
_chain_ids
)
def
train_filter
(
def
deterministic_
train_filter
(
prot_data_cache_entry
:
Any
,
generator
:
torch
.
Generator
,
max_resolution
:
float
=
9.
,
max_single_aa_prop
:
float
=
0.8
,
)
->
bool
:
...
...
@@ -233,6 +232,12 @@ def train_filter(
if
(
largest_single_aa_prop
>
max_single_aa_prop
):
return
False
return
True
def
get_stochastic_train_filter_prob
(
prot_data_cache_entry
:
Any
,
)
->
List
[
float
]:
# Stochastic filters
probabilities
=
[]
...
...
@@ -243,14 +248,12 @@ def train_filter(
chain_length
=
len
(
prot_data_cache_entry
[
"seq"
])
probabilities
.
append
((
1
/
512
)
*
(
max
(
min
(
chain_length
,
512
),
256
)))
weights
=
[[
1
-
p
,
p
]
for
p
in
probabilities
]
results
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
1
,
generator
=
generator
,
)
# Risk of underflow here?
out
=
1
for
p
in
probabilities
:
out
*=
p
return
torch
.
all
(
results
)
return
out
class
OpenFoldDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
@@ -265,7 +268,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
prot_data_cache_paths
:
List
[
str
],
filter_fn
:
Optional
[
Any
]
=
train_filter
,
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
...
...
@@ -273,8 +275,12 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
probabilities
=
probabilities
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
filter_fn
=
filter_fn
self
.
prot_data_caches
=
[]
for
path
in
prot_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
prot_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
# Uniformly shuffle each dataset's indices
...
...
@@ -288,16 +294,40 @@ class OpenFoldDataset(torch.utils.data.Dataset):
for
idx
in
shuf
:
yield
idx
self
.
shuffled_idx_iters
=
[]
for
d
in
datasets
:
self
.
shuffled_idx_iters
.
append
(
looped_shuffled_dataset_idx
(
len
(
d
))
)
def
looped_samples
(
dataset_idx
):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
prot_data_cache
=
self
.
prot_data_caches
[
dataset_idx
]
while
True
:
weights
=
[]
idx
=
[]
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
prot_data_cache_entry
=
prot_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
prot_data_cache_entry
)):
continue
p
=
get_stochastic_train_filter_prob
(
prot_data_cache_entry
,
)
weights
.
append
([
1.
-
p
,
p
])
idx
.
append
(
candidate_idx
)
samples
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
1
,
generator
=
self
.
generator
,
)
samples
=
samples
.
squeeze
()
self
.
prot_data_caches
=
[]
for
path
in
prot_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
prot_data_caches
.
append
(
json
.
load
(
fp
))
cache
=
[
i
for
i
,
s
in
zip
(
idx
,
samples
)
if
s
]
for
datapoint_idx
in
cache
:
yield
datapoint_idx
self
.
_samples
=
[
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
if
(
_roll_at_init
):
self
.
reroll
()
...
...
@@ -319,15 +349,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
self
.
shuffled_idx_iters
[
dataset_idx
]
prot_data_cache
=
self
.
prot_data_caches
[
dataset_idx
]
datapoint_idx
=
None
while
datapoint_idx
is
None
:
candidate_idx
=
next
(
idx_iter
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
if
(
self
.
filter_fn
(
prot_data_cache
[
chain_id
],
self
.
generator
)):
datapoint_idx
=
candidate_idx
samples
=
self
.
_samples
[
dataset_idx
]
datapoint_idx
=
next
(
samples
)
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
...
...
@@ -448,7 +471,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_filter_fn
:
Optional
[
Any
]
=
train_filter
,
train_prot_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
...
...
@@ -474,7 +496,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
max_template_date
=
max_template_date
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_filter_fn
=
train_filter_fn
self
.
train_prot_data_cache_path
=
train_prot_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
...
...
@@ -517,21 +538,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
cache_missing
=
(
train_filter_fn
and
(
train_prot_data_cache_path
is
None
or
(
distillation_data_dir
is
not
None
and
distillation_prot_data_cache_path
is
None
)
)
)
if
(
cache_missing
):
raise
ValueError
(
"If train_filter_fn is given, so must the protein data caches"
)
# An ad-hoc measure for our particular filesystem restrictions
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
...
...
@@ -599,7 +605,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
prot_data_cache_paths
=
prot_data_cache_paths
,
filter_fn
=
self
.
train_filter_fn
,
_roll_at_init
=
False
,
)
...
...
train_openfold.py
View file @
676b6668
...
...
@@ -68,7 +68,7 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"train/loss"
,
loss
,
logger
=
True
)
self
.
log
(
"train/loss"
,
loss
,
on_step
=
True
,
logger
=
True
)
return
loss
...
...
@@ -151,9 +151,9 @@ def main(args):
if
(
args
.
checkpoint_best_val
):
checkpoint_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoints"
)
mc
=
ModelCheckpoint
(
dirpath
=
checkpoint_dir
,
filename
=
"openfold_{epoch}_{step}_{val_loss:.2f}"
,
monitor
=
"val/loss"
,
mode
=
"max"
,
)
callbacks
.
append
(
mc
)
...
...
@@ -200,6 +200,7 @@ def main(args):
)
if
(
args
.
wandb
):
wdb_logger
.
experiment
.
save
(
args
.
deepspeed_config_path
)
wdb_logger
.
experiment
.
save
(
"openfold/config.py"
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
args
.
num_nodes
>
1
:
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
else
:
...
...
@@ -373,9 +374,6 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
.
add_argument
(
"--obsolete_pdbs_file_path"
,
type
=
str
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment