Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
c4a4df22
"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "2fa93c6946e84a41156f37001df5dd2e56e4f7b5"
Commit
c4a4df22
authored
May 09, 2022
by
Gustaf Ahdritz
Browse files
Trim OpenFoldBatchCollator
parent
954ed3d3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
28 deletions
+52
-28
openfold/data/data_modules.py
openfold/data/data_modules.py
+52
-28
No files found.
openfold/data/data_modules.py
View file @
c4a4df22
...
@@ -37,7 +37,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -37,7 +37,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path
:
Optional
[
str
]
=
None
,
mapping_path
:
Optional
[
str
]
=
None
,
mode
:
str
=
"train"
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
_output_raw
:
bool
=
False
,
_alignment_index
:
Optional
[
Any
]
=
None
_structure_index
:
Optional
[
Any
]
=
None
,
_alignment_index
:
Optional
[
Any
]
=
None
,
):
):
"""
"""
Args:
Args:
...
@@ -84,8 +85,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -84,8 +85,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
mode
=
mode
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
self
.
_output_raw
=
_output_raw
self
.
_structure_index
=
_structure_index
self
.
_alignment_index
=
_alignment_index
self
.
_alignment_index
=
_alignment_index
self
.
supported_exts
=
[
".cif"
,
".core"
,
".pdb"
]
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
if
(
mode
not
in
valid_modes
):
if
(
mode
not
in
valid_modes
):
raise
ValueError
(
f
'mode must be one of
{
valid_modes
}
'
)
raise
ValueError
(
f
'mode must be one of
{
valid_modes
}
'
)
...
@@ -103,7 +107,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -103,7 +107,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
else
:
else
:
with
open
(
mapping_path
,
"r"
)
as
f
:
with
open
(
mapping_path
,
"r"
)
as
f
:
self
.
_chain_ids
=
[
l
.
strip
()
for
l
in
f
.
readlines
()]
self
.
_chain_ids
=
[
l
.
strip
()
for
l
in
f
.
readlines
()]
self
.
_chain_id_to_idx_dict
=
{
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
}
}
...
@@ -173,24 +177,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -173,24 +177,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id
=
None
chain_id
=
None
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
if
(
os
.
path
.
exists
(
path
+
".cif"
)):
structure_index_entry
=
None
if
(
self
.
_structure_index
is
not
None
):
structure_index_entry
=
self
.
_structure_index
[
name
]
assert
(
len
(
structure_index_entry
[
"files"
])
==
1
)
filename
,
_
,
_
=
structure_index_entry
[
"files"
][
0
]
ext
=
os
.
path
.
splitext
(
filename
)[
1
]
else
:
ext
=
None
for
e
in
self
.
supported_exts
:
if
(
os
.
path
.
exists
(
path
+
e
)):
ext
=
e
break
if
(
ext
is
None
):
raise
ValueError
(
"Invalid file type"
)
path
+=
ext
if
(
ext
==
".cif"
):
data
=
self
.
_parse_mmcif
(
data
=
self
.
_parse_mmcif
(
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
,
path
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
,
)
)
elif
(
os
.
path
.
exists
(
path
+
".core"
)
)
:
elif
(
ext
==
".core"
):
data
=
self
.
data_pipeline
.
process_core
(
data
=
self
.
data_pipeline
.
process_core
(
path
+
".core"
,
alignment_dir
,
_alignment_index
,
path
,
alignment_dir
,
_alignment_index
,
)
)
elif
(
os
.
path
.
exists
(
path
+
".pdb"
)
)
:
elif
(
ext
==
".pdb"
):
data
=
self
.
data_pipeline
.
process_pdb
(
data
=
self
.
data_pipeline
.
process_pdb
(
pdb_path
=
path
+
".pdb"
,
pdb_path
=
path
,
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain_id
,
chain_id
=
chain_id
,
_structure_index
=
self
.
_structure_index
[
name
],
_alignment_index
=
_alignment_index
,
_alignment_index
=
_alignment_index
,
)
)
else
:
else
:
raise
ValueError
(
"
Invalid file type
"
)
raise
ValueError
(
"
Extension branch missing
"
)
else
:
else
:
path
=
os
.
path
.
join
(
name
,
name
+
".fasta"
)
path
=
os
.
path
.
join
(
name
,
name
+
".fasta"
)
data
=
self
.
data_pipeline
.
process_fasta
(
data
=
self
.
data_pipeline
.
process_fasta
(
...
@@ -206,6 +228,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -206,6 +228,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
,
self
.
mode
data
,
self
.
mode
)
)
feats
[
"batch_idx"
]
=
torch
.
tensor
([
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
return
feats
def
__len__
(
self
):
def
__len__
(
self
):
...
@@ -355,20 +379,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -355,20 +379,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class
OpenFoldBatchCollator
:
class
OpenFoldBatchCollator
:
def
__init__
(
self
,
config
,
stage
=
"train"
):
def
__call__
(
self
,
prots
):
self
.
stage
=
stage
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
__call__
(
self
,
raw_prots
):
processed_prots
=
[]
for
prot
in
raw_prots
:
features
=
self
.
feature_pipeline
.
process_features
(
prot
,
self
.
stage
)
processed_prots
.
append
(
features
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
return
dict_multimap
(
stack_fn
,
processed_
prots
)
return
dict_multimap
(
stack_fn
,
prots
)
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
...
@@ -486,7 +499,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -486,7 +499,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
train_epoch_len
:
int
=
50000
,
train_epoch_len
:
int
=
50000
,
_distillation_structure_index_path
:
Optional
[
str
]
=
None
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
_distillation_alignment_index_path
:
Optional
[
str
]
=
None
,
**
kwargs
**
kwargs
):
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
@@ -539,11 +554,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -539,11 +554,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
# An ad-hoc measure for our particular filesystem restrictions
# An ad-hoc measure for our particular filesystem restrictions
self
.
_distillation_structure_index
=
None
if
(
_distillation_structure_index_path
is
not
None
):
with
open
(
_distillation_structure_index_path
,
"r"
)
as
fp
:
self
.
_distillation_structure_index
=
json
.
load
(
fp
)
self
.
_alignment_index
=
None
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
if
(
_alignment_index_path
is
not
None
):
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
self
.
_alignment_index
=
json
.
load
(
fp
)
self
.
_alignment_index
=
json
.
load
(
fp
)
self
.
_distillation_alignment_index
=
None
if
(
_distillation_alignment_index_path
is
not
None
):
with
open
(
_distillation_alignment_index_path
,
"r"
)
as
fp
:
self
.
_distillation_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
...
@@ -567,7 +592,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -567,7 +592,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
treat_pdb_as_distillation
=
False
,
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
_alignment_index
=
self
.
_alignment_index
,
_alignment_index
=
self
.
_alignment_index
,
)
)
...
@@ -577,10 +601,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -577,10 +601,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
data_dir
=
self
.
distillation_data_dir
,
data_dir
=
self
.
distillation_data_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
mapping_path
=
self
.
distillation_mapping_path
,
mapping_path
=
self
.
distillation_mapping_path
,
max_template_hits
=
self
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
treat_pdb_as_distillation
=
True
,
treat_pdb_as_distillation
=
True
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
_structure_index
=
self
.
_distillation_structure_index
,
_alignment_index
=
self
.
_distillation_alignment_index
,
)
)
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
...
@@ -588,7 +613,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -588,7 +613,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
distillation_dataset
is
not
None
):
if
(
distillation_dataset
is
not
None
):
datasets
=
[
train_dataset
,
distillation_dataset
]
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1
-
d_prob
,
d_prob
]
probabilities
=
[
1
.
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
self
.
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
...
@@ -615,7 +640,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -615,7 +640,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
mapping_path
=
None
,
mapping_path
=
None
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
mode
=
"eval"
,
_output_raw
=
True
,
)
)
else
:
else
:
self
.
eval_dataset
=
None
self
.
eval_dataset
=
None
...
@@ -646,7 +670,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -646,7 +670,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else
:
else
:
raise
ValueError
(
"Invalid stage"
)
raise
ValueError
(
"Invalid stage"
)
batch_collator
=
OpenFoldBatchCollator
(
self
.
config
,
stage
)
batch_collator
=
OpenFoldBatchCollator
()
dl
=
OpenFoldDataLoader
(
dl
=
OpenFoldDataLoader
(
dataset
,
dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment