Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1dc2748c
Commit
1dc2748c
authored
Aug 09, 2023
by
Christina Floristean
Browse files
Additional changes to seeding
parent
39d0ef43
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
16 deletions
+14
-16
openfold/data/data_modules.py
openfold/data/data_modules.py
+3
-7
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+10
-7
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+1
-2
No files found.
openfold/data/data_modules.py
View file @
1dc2748c
...
@@ -7,7 +7,6 @@ import pickle
...
@@ -7,7 +7,6 @@ import pickle
from
typing
import
Optional
,
Sequence
,
List
,
Any
from
typing
import
Optional
,
Sequence
,
List
,
Any
import
ml_collections
as
mlc
import
ml_collections
as
mlc
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch
import
torch
from
torch.utils.data
import
RandomSampler
from
torch.utils.data
import
RandomSampler
...
@@ -428,10 +427,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -428,10 +427,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
config
=
config
self
.
config
=
config
self
.
stage
=
stage
self
.
stage
=
stage
if
(
generator
is
None
):
generator
=
torch
.
Generator
()
self
.
generator
=
generator
self
.
generator
=
generator
self
.
_prep_batch_properties_probs
()
self
.
_prep_batch_properties_probs
()
...
@@ -687,8 +682,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -687,8 +682,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
def
_gen_dataloader
(
self
,
stage
):
def
_gen_dataloader
(
self
,
stage
):
generator
=
torch
.
Generator
()
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
dataset
=
None
dataset
=
None
...
...
openfold/data/data_transforms.py
View file @
1dc2748c
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
import
itertools
import
itertools
from
functools
import
reduce
,
wraps
from
functools
import
reduce
,
wraps
from
operator
import
add
from
operator
import
add
import
random
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -184,11 +183,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
...
@@ -184,11 +183,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@
curry1
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if
(
seed
is
None
):
seed
=
random
.
randint
(
0
,
2147483647
)
num_seq
=
protein
[
"msa"
].
shape
[
0
]
num_seq
=
protein
[
"msa"
].
shape
[
0
]
g
=
None
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
index_order
=
torch
.
cat
(
index_order
=
torch
.
cat
(
(
torch
.
tensor
([
0
],
device
=
shuffled
.
device
),
shuffled
),
(
torch
.
tensor
([
0
],
device
=
shuffled
.
device
),
shuffled
),
...
@@ -1141,8 +1142,10 @@ def random_crop_to_size(
...
@@ -1141,8 +1142,10 @@ def random_crop_to_size(
):
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
seq_length
=
protein
[
"seq_length"
]
seq_length
=
protein
[
"seq_length"
]
...
...
openfold/data/input_pipeline.py
View file @
1dc2748c
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
import
random
import
random
import
torch
import
torch
...
@@ -154,7 +153,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -154,7 +153,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed
=
random
.
randint
(
0
,
2147483647
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
def
wrap_ensemble_fn
(
data
,
i
):
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
"""Function to be mapped over the ensemble dimension."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment