Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1715e3d5
Commit
1715e3d5
authored
Oct 22, 2021
by
Gustaf Ahdritz
Browse files
Improve global seeding, clean up training script a little
parent
971c41d2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
12 deletions
+50
-12
openfold/utils/seed.py
openfold/utils/seed.py
+19
-0
openfold/utils/suppress_output.py
openfold/utils/suppress_output.py
+26
-0
train_openfold.py
train_openfold.py
+5
-12
No files found.
openfold/utils/seed.py
0 → 100644
View file @
1715e3d5
import
os
import
logging
import
random
import
numpy
as
np
from
pytorch_lightning.utilities.seed
import
seed_everything
from
openfold.utils.suppress_output
import
SuppressLogging
def
seed_globally
(
seed
=
None
):
if
(
"PL_GLOBAL_SEED"
not
in
os
.
environ
):
if
(
seed
is
None
):
seed
=
random
.
randint
(
0
,
np
.
iinfo
(
np
.
uint32
).
max
)
os
.
environ
[
"PL_GLOBAL_SEED"
]
=
str
(
seed
)
print
(
f
'os.environ["PL_GLOBAL_SEED"] set to
{
seed
}
'
)
# seed_everything is a bit log-happy
with
SuppressLogging
(
logging
.
INFO
):
seed_everything
(
seed
=
None
)
openfold/utils/suppress_output.py
0 → 100644
View file @
1715e3d5
import
logging
import
sys
class
SuppressStdout
:
def
__enter__
(
self
):
self
.
stdout
=
sys
.
stdout
dev_null
=
open
(
"/dev/null"
,
"w"
)
sys
.
stdout
=
dev_null
def
__exit__
(
self
,
typ
,
value
,
traceback
):
fp
=
sys
.
stdout
sys
.
stdout
=
self
.
stdout
fp
.
close
()
class
SuppressLogging
:
def
__init__
(
self
,
level
):
self
.
level
=
level
def
__enter__
(
self
):
logging
.
disable
(
self
.
level
)
def
__exit__
(
self
,
typ
,
value
,
traceback
):
logging
.
disable
(
logging
.
NOTSET
)
train_openfold.py
View file @
1715e3d5
...
@@ -2,7 +2,7 @@ import argparse
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
logging
import
os
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"6"
#
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
#os.environ["NODE_RANK"]="0"
...
@@ -25,10 +25,9 @@ from openfold.data.data_modules import (
...
@@ -25,10 +25,9 @@ from openfold.data.data_modules import (
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
copy
class
OpenFoldWrapper
(
pl
.
LightningModule
):
class
OpenFoldWrapper
(
pl
.
LightningModule
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -91,6 +90,9 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -91,6 +90,9 @@ class OpenFoldWrapper(pl.LightningModule):
def
main
(
args
):
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
seed_everything
(
args
.
seed
)
config
=
model_config
(
config
=
model_config
(
"model_1"
,
"model_1"
,
train
=
True
,
train
=
True
,
...
@@ -111,9 +113,6 @@ def main(args):
...
@@ -111,9 +113,6 @@ def main(args):
if
(
args
.
deepspeed_config_path
is
not
None
):
if
(
args
.
deepspeed_config_path
is
not
None
):
plugins
.
append
(
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
))
plugins
.
append
(
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
))
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
#plugins.append(DDPPlugin(find_unused_parameters=True))
trainer
=
pl
.
Trainer
.
from_argparse_args
(
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
args
,
plugins
=
plugins
,
plugins
=
plugins
,
...
@@ -196,10 +195,4 @@ if __name__ == "__main__":
...
@@ -196,10 +195,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
(
args
.
seed
is
not
None
):
torch
.
manual_seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
+
1
)
np
.
random
.
seed
(
args
.
seed
+
2
)
args
.
seed
+=
1
main
(
args
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment