Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b9faee76
Commit
b9faee76
authored
Feb 10, 2022
by
Gustaf Ahdritz
Browse files
Change name of prot data cache
parent
2864b7ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
28 deletions
+28
-28
README.md
README.md
+1
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+25
-25
train_openfold.py
train_openfold.py
+2
-2
No files found.
README.md
View file @
b9faee76
...
...
@@ -215,7 +215,7 @@ python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
--deepspeed_config_path
deepspeed_config.json
\
--checkpoint_every_epoch
\
--resume_from_ckpt
ckpt_dir/
\
--train_
prot
_data_cache_path
chain_data_cache.json
--train_
chain
_data_cache_path
chain_data_cache.json
```
where
`--template_release_dates_cache_path`
is a path to the
`.json`
file
...
...
openfold/data/data_modules.py
View file @
b9faee76
...
...
@@ -201,16 +201,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def
deterministic_train_filter
(
prot
_data_cache_entry
:
Any
,
chain
_data_cache_entry
:
Any
,
max_resolution
:
float
=
9.
,
max_single_aa_prop
:
float
=
0.8
,
)
->
bool
:
# Hard filters
resolution
=
prot
_data_cache_entry
.
get
(
"resolution"
,
None
)
resolution
=
chain
_data_cache_entry
.
get
(
"resolution"
,
None
)
if
(
resolution
is
not
None
and
resolution
>
max_resolution
):
return
False
seq
=
prot
_data_cache_entry
[
"seq"
]
seq
=
chain
_data_cache_entry
[
"seq"
]
counts
=
{}
for
aa
in
seq
:
counts
.
setdefault
(
aa
,
0
)
...
...
@@ -224,16 +224,16 @@ def deterministic_train_filter(
def
get_stochastic_train_filter_prob
(
prot
_data_cache_entry
:
Any
,
chain
_data_cache_entry
:
Any
,
)
->
List
[
float
]:
# Stochastic filters
probabilities
=
[]
cluster_size
=
prot
_data_cache_entry
.
get
(
"cluster_size"
,
None
)
cluster_size
=
chain
_data_cache_entry
.
get
(
"cluster_size"
,
None
)
if
(
cluster_size
is
not
None
and
cluster_size
>
0
):
probabilities
.
append
(
1
/
cluster_size
)
chain_length
=
len
(
prot
_data_cache_entry
[
"seq"
])
chain_length
=
len
(
chain
_data_cache_entry
[
"seq"
])
probabilities
.
append
((
1
/
512
)
*
(
max
(
min
(
chain_length
,
512
),
256
)))
# Risk of underflow here?
...
...
@@ -255,7 +255,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
prot
_data_cache_paths
:
List
[
str
],
chain
_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
...
...
@@ -264,10 +264,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
prot
_data_caches
=
[]
for
path
in
prot
_data_cache_paths
:
self
.
chain
_data_caches
=
[]
for
path
in
chain
_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
prot
_data_caches
.
append
(
json
.
load
(
fp
))
self
.
chain
_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
...
...
@@ -286,19 +286,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
prot
_data_cache
=
self
.
prot
_data_caches
[
dataset_idx
]
chain
_data_cache
=
self
.
chain
_data_caches
[
dataset_idx
]
while
True
:
weights
=
[]
idx
=
[]
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
prot
_data_cache_entry
=
prot
_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
prot
_data_cache_entry
)):
chain
_data_cache_entry
=
chain
_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
chain
_data_cache_entry
)):
continue
p
=
get_stochastic_train_filter_prob
(
prot
_data_cache_entry
,
chain
_data_cache_entry
,
)
weights
.
append
([
1.
-
p
,
p
])
idx
.
append
(
candidate_idx
)
...
...
@@ -459,10 +459,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_
prot
_data_cache_path
:
Optional
[
str
]
=
None
,
train_
chain
_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_
prot
_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_
chain
_data_cache_path
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
predict_data_dir
:
Optional
[
str
]
=
None
,
...
...
@@ -483,11 +483,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
max_template_date
=
max_template_date
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_
prot
_data_cache_path
=
train_
prot
_data_cache_path
self
.
train_
chain
_data_cache_path
=
train_
chain
_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_
prot
_data_cache_path
=
(
distillation_
prot
_data_cache_path
self
.
distillation_
chain
_data_cache_path
=
(
distillation_
chain
_data_cache_path
)
self
.
val_data_dir
=
val_data_dir
self
.
val_alignment_dir
=
val_alignment_dir
...
...
@@ -569,22 +569,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1
-
d_prob
,
d_prob
]
prot
_data_cache_paths
=
[
self
.
train_
prot
_data_cache_path
,
self
.
distillation_
prot
_data_cache_path
,
chain
_data_cache_paths
=
[
self
.
train_
chain
_data_cache_path
,
self
.
distillation_
chain
_data_cache_path
,
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
prot
_data_cache_paths
=
[
self
.
train_
prot
_data_cache_path
,
chain
_data_cache_paths
=
[
self
.
train_
chain
_data_cache_path
,
]
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
prot
_data_cache_paths
=
prot
_data_cache_paths
,
chain
_data_cache_paths
=
chain
_data_cache_paths
,
_roll_at_init
=
False
,
)
...
...
train_openfold.py
View file @
b9faee76
...
...
@@ -358,10 +358,10 @@ if __name__ == "__main__":
help
=
"Whether to TorchScript eligible components of them model"
)
parser
.
add_argument
(
"--train_
prot
_data_cache_path"
,
type
=
str
,
default
=
None
,
"--train_
chain
_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--distillation_
prot
_data_cache_path"
,
type
=
str
,
default
=
None
,
"--distillation_
chain
_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment