Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
127f1e70
Unverified
Commit
127f1e70
authored
Mar 21, 2024
by
Jennifer Wei
Committed by
GitHub
Mar 21, 2024
Browse files
Merge pull request #418 from aqlaboratory/seeding-fix
Fix distributed seeding behavior
parents
ef0c9fac
a56ea9b5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
2 additions
and
47 deletions
+2
-47
openfold/utils/seed.py
openfold/utils/seed.py
+0
-19
openfold/utils/suppress_output.py
openfold/utils/suppress_output.py
+0
-26
train_openfold.py
train_openfold.py
+2
-2
No files found.
openfold/utils/seed.py
deleted
100644 → 0
View file @
ef0c9fac
import
os
import
logging
import
random
import
numpy
as
np
from
pytorch_lightning.utilities.seed
import
seed_everything
from
openfold.utils.suppress_output
import
SuppressLogging
def
seed_globally
(
seed
=
None
):
if
(
"PL_GLOBAL_SEED"
not
in
os
.
environ
):
if
(
seed
is
None
):
seed
=
random
.
randint
(
0
,
np
.
iinfo
(
np
.
uint32
).
max
)
os
.
environ
[
"PL_GLOBAL_SEED"
]
=
str
(
seed
)
logging
.
info
(
f
'os.environ["PL_GLOBAL_SEED"] set to
{
seed
}
'
)
# seed_everything is a bit log-happy
with
SuppressLogging
(
logging
.
INFO
):
seed_everything
(
seed
=
None
)
openfold/utils/suppress_output.py
deleted
100644 → 0
View file @
ef0c9fac
import
logging
import
sys
class
SuppressStdout
:
def
__enter__
(
self
):
self
.
stdout
=
sys
.
stdout
dev_null
=
open
(
"/dev/null"
,
"w"
)
sys
.
stdout
=
dev_null
def
__exit__
(
self
,
typ
,
value
,
traceback
):
fp
=
sys
.
stdout
sys
.
stdout
=
self
.
stdout
fp
.
close
()
class
SuppressLogging
:
def
__init__
(
self
,
level
):
self
.
level
=
level
def
__enter__
(
self
):
logging
.
disable
(
self
.
level
)
def
__exit__
(
self
,
typ
,
value
,
traceback
):
logging
.
disable
(
logging
.
NOTSET
)
train_openfold.py
View file @
127f1e70
...
@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
...
@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.utilities.seed
import
seed_everything
import
torch
import
torch
from
openfold.config
import
model_config
from
openfold.config
import
model_config
...
@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
...
@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.validation_metrics
import
(
from
openfold.utils.validation_metrics
import
(
...
@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
def
main
(
args
):
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
if
(
args
.
seed
is
not
None
):
seed_everything
(
args
.
seed
)
seed_everything
(
args
.
seed
,
workers
=
True
)
config
=
model_config
(
config
=
model_config
(
args
.
config_preset
,
args
.
config_preset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment