OpenDAS / OpenFold · Commits

Commit 3d5e8740
Fix RNG bug
Authored Dec 27, 2021 by Gustaf Ahdritz
Parent: 43116de0

Showing 1 changed file with 88 additions and 44 deletions:
openfold/data/data_modules.py (+88, -44)
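The commit message doesn't spell out the bug, but the diff below moves the generator-driven sampling of per-batch properties (use_clamped_fape, no_recycling_iters) out of OpenFoldBatchCollator, whose collate_fn copy runs inside each DataLoader worker process, and into OpenFoldDataLoader.__iter__, which runs in the parent process. A minimal sketch (not OpenFold code) of the failure mode that makes this distinction matter: workers that each start from a copy of the same torch.Generator state produce identical "random" draws, while a single generator owned by the parent advances from batch to batch.

    import torch

    parent_gen = torch.Generator().manual_seed(42)
    state = parent_gen.get_state()

    # Two simulated workers, each starting from a copy of the same generator
    # state (what happens when an object holding the generator is replicated
    # into every worker process): their draws are identical.
    worker_draws = []
    for _ in range(2):
        g = torch.Generator()
        g.set_state(state)
        worker_draws.append(
            torch.multinomial(torch.ones(4), 3, replacement=True, generator=g)
        )
    assert worker_draws[0].tolist() == worker_draws[1].tolist()

    # A single generator that stays in the parent process keeps advancing,
    # so each batch can get a fresh draw.
    parent_draws = [
        torch.multinomial(torch.ones(4), 1, generator=parent_gen).item()
        for _ in range(4)
    ]
    print(parent_draws)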
...
...
@@ -244,12 +244,33 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
 class OpenFoldBatchCollator:
-    def __init__(self, config, generator, stage="train"):
+    def __init__(self, config, stage="train"):
         self.config = config
-        self.generator = generator
         self.stage = stage
         self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
-        self._prep_batch_properties_probs()
+
+    def __call__(self, raw_prots):
+        processed_prots = []
+        for prot in raw_prots:
+            features = self.feature_pipeline.process_features(
+                prot, self.stage
+            )
+            processed_prots.append(features)
+
+        stack_fn = partial(torch.stack, dim=0)
+        return dict_multimap(stack_fn, processed_prots)
+
+
+class OpenFoldDataLoader(torch.utils.data.DataLoader):
+    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config = config
+        self.stage = stage
+
+        if(generator is None):
+            generator = torch.Generator()
+
+        self.generator = generator
+        self._prep_batch_properties_probs()
+
     def _prep_batch_properties_probs(self):
         keyed_probs = []
         stage_cfg = self.config[self.stage]
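For reference, OpenFoldBatchCollator.__call__ above stacks a list of per-sample feature dicts into one batched dict. Below is a small illustrative stand-in for dict_multimap (the real helper lives in OpenFold's tensor utilities; this sketch just assumes it applies the function key-wise, recursing into nested dicts):

    from functools import partial
    import torch

    def dict_multimap_sketch(fn, dicts):
        # Apply fn to the list of values collected under each key; recurse
        # into nested dicts. Illustrative only, not OpenFold's implementation.
        first = dicts[0]
        return {
            k: dict_multimap_sketch(fn, [d[k] for d in dicts])
            if isinstance(first[k], dict)
            else fn([d[k] for d in dicts])
            for k in first
        }

    prots = [
        {"aatype": torch.zeros(8, 4), "seq_length": torch.tensor(8)},
        {"aatype": torch.ones(8, 4), "seq_length": torch.tensor(8)},
    ]
    stack_fn = partial(torch.stack, dim=0)
    batch = dict_multimap_sketch(stack_fn, prots)
    print(batch["aatype"].shape)      # torch.Size([2, 8, 4]): new leading batch dim
    print(batch["seq_length"].shape)  # torch.Size([2])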
...
...
@@ -259,7 +280,7 @@ class OpenFoldBatchCollator:
             clamp_prob = self.config.supervised.clamp_prob
             keyed_probs.append(
                 ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
             )
 
         if(self.config.supervised.uniform_recycling):
             recycling_probs = [
                 1. / (max_iters + 1) for _ in range(max_iters + 1)
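A quick worked example of the probabilities built here (values assumed, not taken from the OpenFold config): with clamp_prob = 0.9 and max_iters = 3, the two property distributions are [0.1, 0.9] for use_clamped_fape and a uniform [0.25, 0.25, 0.25, 0.25] over 0-3 recycling iterations. Rows of different lengths can be zero-padded before a single torch.multinomial call, since zero-weight entries are never drawn:

    import torch

    clamp_prob = 0.9   # assumed value for illustration
    max_iters = 3      # assumed value for illustration

    use_clamped_fape_probs = [1 - clamp_prob, clamp_prob]
    recycling_probs = [1. / (max_iters + 1) for _ in range(max_iters + 1)]
    print(recycling_probs)  # [0.25, 0.25, 0.25, 0.25]

    # Zero-pad the shorter row so both fit in one 2-D tensor; zero-weight
    # entries are never selected by torch.multinomial.
    prop_probs_tensor = torch.tensor(
        [use_clamped_fape_probs + [0., 0.], recycling_probs],
        dtype=torch.float32,
    )
    samples = torch.multinomial(prop_probs_tensor, num_samples=1)  # one draw per row
    print(samples.squeeze(-1).tolist())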
...
...
@@ -286,7 +307,7 @@ class OpenFoldBatchCollator:
             dtype=torch.float32,
         )
 
-    def _add_batch_properties(self, raw_prots):
+    def _add_batch_properties(self, batch):
         samples = torch.multinomial(
             self.prop_probs_tensor,
             num_samples=1, # 1 per row
...
...
@@ -294,22 +315,42 @@ class OpenFoldBatchCollator:
             generator=self.generator
         )
 
+        aatype = batch["aatype"]
+        batch_dims = aatype.shape[:-2]
+        recycling_dim = aatype.shape[-1]
+        no_recycling = recycling_dim
         for i, key in enumerate(self.prop_keys):
-            sample = samples[i][0]
-            for prot in raw_prots:
-                prot[key] = np.array(sample, dtype=np.float32)
-
-    def __call__(self, raw_prots):
-        self._add_batch_properties(raw_prots)
-        processed_prots = []
-        for prot in raw_prots:
-            features = self.feature_pipeline.process_features(
-                prot, self.stage
-            )
-            processed_prots.append(features)
-
-        stack_fn = partial(torch.stack, dim=0)
-        return dict_multimap(stack_fn, processed_prots)
+            sample = int(samples[i][0])
+            sample_tensor = torch.tensor(
+                sample,
+                device=aatype.device,
+                requires_grad=False
+            )
+            orig_shape = sample_tensor.shape
+            sample_tensor = sample_tensor.view(
+                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
+            )
+            sample_tensor = sample_tensor.expand(
+                batch_dims + orig_shape + (recycling_dim,)
+            )
+            batch[key] = sample_tensor
+
+            if(key == "no_recycling_iters"):
+                no_recycling = sample
+
+        resample_recycling = lambda t: t[..., :no_recycling + 1]
+        batch = tensor_tree_map(resample_recycling, batch)
+
+        return batch
+
+    def __iter__(self):
+        it = super().__iter__()
+
+        def _batch_prop_gen(iterator):
+            for batch in iterator:
+                yield self._add_batch_properties(batch)
+
+        return _batch_prop_gen(it)
 
 
 class OpenFoldDataModule(pl.LightningDataModule):
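A standalone sketch (toy shapes, not OpenFold's real feature dict) of the broadcasting pattern used in _add_batch_properties above: a scalar draw is expanded so it lines up with every batch element and recycling slot, and every tensor in the batch is then truncated to the sampled number of recycling iterations plus one. The dict comprehension stands in for tensor_tree_map:

    import torch

    batch = {
        "aatype": torch.zeros(2, 64, 4),            # [batch, seq, recycling]
        "msa_feat": torch.zeros(2, 128, 64, 49, 4),
    }
    aatype = batch["aatype"]
    batch_dims = aatype.shape[:-2]                  # (2,)
    recycling_dim = aatype.shape[-1]                # 4

    sample = 1                                      # e.g. a drawn no_recycling_iters value
    sample_tensor = torch.tensor(sample)
    orig_shape = sample_tensor.shape                # ()
    sample_tensor = sample_tensor.view((1,) * len(batch_dims) + orig_shape + (1,))
    sample_tensor = sample_tensor.expand(batch_dims + orig_shape + (recycling_dim,))
    print(sample_tensor.shape)                      # torch.Size([2, 4])

    # Truncate every tensor's trailing recycling dimension to sample + 1 slots.
    resample_recycling = lambda t: t[..., :sample + 1]
    batch = {k: resample_recycling(v) for k, v in batch.items()}
    print(batch["msa_feat"].shape)                  # torch.Size([2, 128, 64, 49, 2])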
...
...
@@ -427,7 +468,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             )
 
         if(self.val_data_dir is not None):
-            self.val_dataset = dataset_gen(
+            self.eval_dataset = dataset_gen(
                 data_dir=self.val_data_dir,
                 alignment_dir=self.val_alignment_dir,
                 mapping_path=None,
...
...
@@ -436,7 +477,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 _output_raw=True,
             )
         else:
-            self.val_dataset = None
+            self.eval_dataset = None
     else:
         self.predict_dataset = dataset_gen(
             data_dir=self.predict_data_dir,
...
...
@@ -446,42 +487,45 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 mode="predict",
             )
 
-    def _gen_batch_collator(self, stage):
-        """ We want each process to use the same batch collation seed """
+    def _gen_dataloader(self, stage):
         generator = torch.Generator()
         if(self.batch_seed is not None):
             generator = generator.manual_seed(self.batch_seed)
-        collate_fn = OpenFoldBatchCollator(self.config, generator, stage)
-        return collate_fn
 
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(
-            self.train_dataset,
+        dataset = None
+        if(stage == "train"):
+            dataset = self.train_dataset
+        elif(stage == "eval"):
+            dataset = self.eval_dataset
+        elif(stage == "predict"):
+            dataset = self.predict_dataset
+        else:
+            raise ValueError("Invalid stage")
+
+        batch_collator = OpenFoldBatchCollator(self.config, stage)
+
+        dl = OpenFoldDataLoader(
+            dataset,
+            config=self.config,
+            stage=stage,
+            generator=generator,
             batch_size=self.config.data_module.data_loaders.batch_size,
             num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=self._gen_batch_collator("train"),
+            collate_fn=batch_collator,
         )
 
-    def val_dataloader(self):
-        if(self.val_dataset is not None):
-            return torch.utils.data.DataLoader(
-                self.val_dataset,
-                batch_size=self.config.data_module.data_loaders.batch_size,
-                num_workers=self.config.data_module.data_loaders.num_workers,
-                collate_fn=self._gen_batch_collator("eval")
-            )
+        return dl
+
+    def train_dataloader(self):
+        return self._gen_dataloader("train")
+
+    def val_dataloader(self):
+        if(self.eval_dataset is not None):
+            return self._gen_dataloader("eval")
         return None
 
     def predict_dataloader(self):
-        return torch.utils.data.DataLoader(
-            self.predict_dataset,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=self._gen_batch_collator("predict")
-        )
+        return self._gen_dataloader("predict")
 
 
 class DummyDataset(torch.utils.data.Dataset):
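Finally, a minimal sketch (with a hypothetical batch_seed value) of what seeding the generator from batch_seed in _gen_dataloader buys: the sequence of per-batch property draws becomes reproducible across runs while still varying from batch to batch within a run.

    import torch

    def property_draws(batch_seed, n_batches=5):
        generator = torch.Generator()
        if batch_seed is not None:
            generator = generator.manual_seed(batch_seed)
        probs = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
        return [
            torch.multinomial(probs, num_samples=1, generator=generator).item()
            for _ in range(n_batches)
        ]

    # Same seed, same sequence of draws across runs.
    assert property_draws(17) == property_draws(17)
    print(property_draws(17))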
...
...