Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
fcf9e2f8
Commit
fcf9e2f8
authored
Oct 22, 2021
by
Gustaf Ahdritz
Browse files
Add dummy dataloader, most recent version of training script
parent
6dc8aa7f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
9 deletions
+61
-9
README.md
README.md
+2
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+22
-0
train_openfold.py
train_openfold.py
+37
-8
No files found.
README.md
View file @
fcf9e2f8
...
@@ -13,7 +13,8 @@ cases where the *Nature* paper differs from the source, we always defer to the
...
@@ -13,7 +13,8 @@ cases where the *Nature* paper differs from the source, we always defer to the
latter.
latter.
OpenFold is built to support inference with AlphaFold's original JAX weights.
OpenFold is built to support inference with AlphaFold's original JAX weights.
Try it out with our
[
Colab notebook
](
https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb
)
.
Try it out with our
[
Colab notebook
](
https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb
)
(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with or without
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
with or without
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
...
...
openfold/data/data_modules.py
View file @
fcf9e2f8
...
@@ -2,6 +2,7 @@ from functools import partial
...
@@ -2,6 +2,7 @@ from functools import partial
import
json
import
json
import
logging
import
logging
import
os
import
os
import
pickle
from
typing
import
Optional
,
Sequence
from
typing
import
Optional
,
Sequence
import
ml_collections
as
mlc
import
ml_collections
as
mlc
...
@@ -446,3 +447,24 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -446,3 +447,24 @@ class OpenFoldDataModule(pl.LightningDataModule):
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
self
.
_gen_batch_collator
(
"predict"
)
collate_fn
=
self
.
_gen_batch_collator
(
"predict"
)
)
)
class DummyDataset(torch.utils.data.Dataset):
    """Dataset that serves deep copies of a single pickled batch.

    Useful for overfitting/smoke-testing the training loop: every index
    returns the same example, loaded once from ``batch_path``.
    """

    def __init__(self, batch_path):
        """Load the pickled batch once and cache it on the instance.

        Args:
            batch_path: Path to a pickle file containing one batch.
        """
        # NOTE(review): pickle.load on an untrusted file executes arbitrary
        # code — only use with batches you produced yourself.
        with open(batch_path, "rb") as f:
            # Bug fix: the original bound the result to a local variable,
            # so __getitem__'s self.batch raised AttributeError.
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        # Deep-copy so callers (e.g. collators that mutate tensors in
        # place) cannot corrupt the cached batch.
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Arbitrary fixed epoch length for the dummy data.
        return 1000
class DummyDataLoader(pl.LightningDataModule):
    """LightningDataModule wrapping DummyDataset for overfitting runs.

    Intended usage (see train_openfold.py):
        data_module = DummyDataLoader("batch.pickle")
    """

    def __init__(self, batch_path):
        """Build the dummy dataset from a pickled batch file.

        Args:
            batch_path: Path to the pickle file consumed by DummyDataset.
        """
        super().__init__()
        # Bug fix: the original called undefined ``Dataset()`` and ignored
        # the batch path the caller passes ("batch.pickle").
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        """Return a DataLoader over the single cached batch."""
        return torch.utils.data.DataLoader(self.dataset)
train_openfold.py
View file @
fcf9e2f8
...
@@ -2,25 +2,33 @@ import argparse
...
@@ -2,25 +2,33 @@ import argparse
import
logging
import
logging
import
os
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"5"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
import
random
import
random
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.plugins
import
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
import
torch
import
torch
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
(
from
openfold.data.data_modules
import
(
OpenFoldDataModule
,
OpenFoldDataModule
,
DummyDataLoader
,
)
)
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
copy
class
OpenFoldWrapper
(
pl
.
LightningModule
):
class
OpenFoldWrapper
(
pl
.
LightningModule
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -28,12 +36,17 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -28,12 +36,17 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
config
=
config
self
.
config
=
config
self
.
model
=
AlphaFold
(
config
)
self
.
model
=
AlphaFold
(
config
)
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
self
.
model
,
decay
=
config
.
ema
.
decay
)
self
.
ema
=
ExponentialMovingAverage
(
model
=
self
.
model
,
decay
=
config
.
ema
.
decay
)
def
forward
(
self
,
batch
):
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
def
training_step
(
self
,
batch
,
batch_idx
):
def
training_step
(
self
,
batch
,
batch_idx
):
if
(
self
.
ema
.
device
!=
batch
[
"aatype"
].
device
):
self
.
ema
.
to
(
batch
[
"aatype"
].
device
)
# Run the model
# Run the model
outputs
=
self
(
batch
)
outputs
=
self
(
batch
)
...
@@ -84,18 +97,29 @@ def main(args):
...
@@ -84,18 +97,29 @@ def main(args):
low_prec
=
(
args
.
precision
==
16
)
low_prec
=
(
args
.
precision
==
16
)
)
)
model_module
=
OpenFoldWrapper
(
config
)
#data_module = DummyDataLoader("batch.pickle")
data_module
=
OpenFoldDataModule
(
config
=
config
.
data
,
batch_seed
=
args
.
seed
,
**
vars
(
args
)
)
data_module
.
prepare_data
()
data_module
.
setup
()
plugins
=
[]
plugins
=
[]
#plugins.append(DeepSpeedPlugin(config="deepspeed_config.json"))
if
(
args
.
deepspeed_config_path
is
not
None
):
plugins
.
append
(
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
))
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
#plugins.append(DDPPlugin(find_unused_parameters=True))
trainer
=
pl
.
Trainer
.
from_argparse_args
(
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
args
,
plugins
=
plugins
,
plugins
=
plugins
,
)
)
model_module
=
OpenFoldWrapper
(
config
)
trainer
.
fit
(
model_module
,
datamodule
=
data_module
)
data_module
=
OpenFoldDataModule
(
config
=
config
.
data
,
**
vars
(
args
))
trainer
.
fit
(
model_module
,
data_module
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -160,6 +184,10 @@ if __name__ == "__main__":
...
@@ -160,6 +184,10 @@ if __name__ == "__main__":
"--seed"
,
type
=
int
,
default
=
None
,
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"Random seed"
help
=
"Random seed"
)
)
parser
.
add_argument
(
"--deepspeed_config_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
.
set_defaults
(
parser
.
set_defaults
(
...
@@ -172,5 +200,6 @@ if __name__ == "__main__":
...
@@ -172,5 +200,6 @@ if __name__ == "__main__":
torch
.
manual_seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
+
1
)
random
.
seed
(
args
.
seed
+
1
)
np
.
random
.
seed
(
args
.
seed
+
2
)
np
.
random
.
seed
(
args
.
seed
+
2
)
args
.
seed
+=
1
main
(
args
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment