Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
ce000c60
Commit
ce000c60
authored
Apr 12, 2024
by
Jennifer Wei
Browse files
changes required for pytorch2
parent
fdd4e1d8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
27 deletions
+60
-27
environment.yml
environment.yml
+4
-3
openfold/data/data_modules.py
openfold/data/data_modules.py
+4
-4
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+3
-3
openfold/utils/superimposition.py
openfold/utils/superimposition.py
+2
-2
train_openfold.py
train_openfold.py
+47
-15
No files found.
environment.yml
View file @
ce000c60
name
:
openfold-venv
name
:
pytorch1-plupgrade
channels
:
-
conda-forge
-
bioconda
...
...
@@ -11,7 +11,7 @@ dependencies:
-
openmm=7.7
-
pdbfixer
-
cudatoolkit==11.3.*
-
pytorch-lightning==
1.5.10
-
pytorch-lightning==
2.0.9
-
biopython==1.79
-
numpy==1.21
-
pandas==2.0
...
...
@@ -19,7 +19,7 @@ dependencies:
-
requests
-
scipy==1.7
-
tqdm==4.62.2
-
typing-extensions==
3.1
0
-
typing-extensions==
4.
0
-
wandb==0.12.21
-
modelcif==0.7
-
awscli
...
...
@@ -31,6 +31,7 @@ dependencies:
-
bioconda::kalign2==2.04
-
pytorch::pytorch=1.12.*
-
pip
:
-
mpi4py==3.1.5
-
deepspeed==0.12.4
-
dm-tree==0.1.6
-
git+https://github.com/NVIDIA/dllogger.git
...
...
openfold/data/data_modules.py
View file @
ce000c60
...
...
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
with
open
(
distillation_alignment_index_path
,
"r"
)
as
fp
:
self
.
distillation_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
def
setup
(
self
,
stage
=
None
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
...
...
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode
=
"predict"
,
)
def
_gen_dataloader
(
self
,
stage
):
def
_gen_dataloader
(
self
,
stage
=
None
):
generator
=
None
if
self
.
batch_seed
is
not
None
:
generator
=
torch
.
Generator
()
...
...
@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
def
val_dataloader
(
self
):
if
self
.
eval_dataset
is
not
None
:
return
self
.
_gen_dataloader
(
"eval"
)
return
None
return
[]
def
predict_dataloader
(
self
):
return
self
.
_gen_dataloader
(
"predict"
)
...
...
@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self
.
training_mode
=
self
.
train_data_dir
is
not
None
self
.
val_mmcif_data_cache_path
=
val_mmcif_data_cache_path
def
setup
(
self
):
def
setup
(
self
,
setup
=
None
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleMultimerDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
...
...
openfold/data/data_pipeline.py
View file @
ce000c60
...
...
@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features
[
"num_alignments"
]
=
np
.
array
(
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
np
.
object
_
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
object
)
return
features
...
...
@@ -590,7 +590,7 @@ def convert_monomer_features(
)
->
FeatureDict
:
"""Reshapes and modifies monomer features for multimer models."""
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object
_
)
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
object
)
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
...
...
@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
np
.
object
_
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
object
)
mmcif_feats
[
"is_distillation"
]
=
np
.
array
(
0.
,
dtype
=
np
.
float32
)
...
...
openfold/utils/superimposition.py
View file @
ce000c60
...
...
@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
def
_superimpose_single
(
reference
,
coords
):
reference_np
=
reference
.
detach
().
cpu
().
numpy
()
coords_np
=
coords
.
detach
().
cpu
().
numpy
()
reference_np
=
reference
.
detach
().
to
(
torch
.
float
).
cpu
().
numpy
()
coords_np
=
coords
.
detach
().
to
(
torch
.
float
).
cpu
().
numpy
()
superimposed
,
rmsd
=
_superimpose_np
(
reference_np
,
coords_np
)
return
coords
.
new_tensor
(
superimposed
),
coords
.
new_tensor
(
rmsd
)
...
...
train_openfold.py
View file @
ce000c60
...
...
@@ -8,8 +8,11 @@ import pytorch_lightning as pl
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.strategies
import
DDPStrategy
,
DeepSpeedStrategy
from
pytorch_lightning.plugins.environments
import
MPIEnvironment
from
pytorch_lightning
import
seed_everything
import
torch
import
wandb
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldDataModule
,
OpenFoldMultimerDataModule
...
...
@@ -24,7 +27,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.validation_metrics
import
(
...
...
@@ -59,7 +61,7 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
cached_weights
=
None
self
.
last_lr_step
=
-
1
self
.
save_hyperparameters
self
.
save_hyperparameters
()
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
...
...
@@ -70,14 +72,15 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
log
(
f
"
{
phase
}
/
{
loss_name
}
"
,
indiv_loss
,
on_step
=
train
,
on_epoch
=
(
not
train
),
logger
=
True
,
prog_bar
=
(
loss_name
==
'loss'
),
on_step
=
train
,
on_epoch
=
(
not
train
),
logger
=
True
,
sync_dist
=
False
,
)
if
(
train
):
self
.
log
(
f
"
{
phase
}
/
{
loss_name
}
_epoch"
,
indiv_loss
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
sync_dist
=
False
,
)
with
torch
.
no_grad
():
...
...
@@ -91,7 +94,8 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
log
(
f
"
{
phase
}
/
{
k
}
"
,
torch
.
mean
(
v
),
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
prog_bar
=
(
k
==
'loss'
),
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
sync_dist
=
False
,
)
def
training_step
(
self
,
batch
,
batch_idx
):
...
...
@@ -154,7 +158,7 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
_log
(
loss_breakdown
,
batch
,
outputs
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
def
on_
validation_epoch_end
(
self
):
# Restore the model weights to normal
self
.
model
.
load_state_dict
(
self
.
cached_weights
)
self
.
cached_weights
=
None
...
...
@@ -377,40 +381,59 @@ def main(args):
callbacks
.
append
(
lr_monitor
)
loggers
=
[]
is_rank_zero
=
int
(
os
.
environ
.
get
(
"PMI_RANK"
))
==
0
if
(
args
.
wandb
):
if
args
.
mpi_plugin
and
is_rank_zero
:
wandb_init_dict
=
dict
(
name
=
args
.
experiment_name
,
project
=
args
.
wandb_project
,
id
=
args
.
wandb_id
,
dir
=
args
.
output_dir
,
resume
=
"allow"
,
anonymous
=
None
,
entity
=
args
.
wandb_entity
)
wandb
.
run
=
wandb
.
init
(
**
wandb_init_dict
)
wdb_logger
=
WandbLogger
(
name
=
args
.
experiment_name
,
save_dir
=
args
.
output_dir
,
id
=
args
.
wandb_id
,
project
=
args
.
wandb_project
,
config
=
config
.
to_dict
(),
**
{
"entity"
:
args
.
wandb_entity
}
)
loggers
.
append
(
wdb_logger
)
cluster_environment
=
MPIEnvironment
()
if
args
.
mpi_plugin
else
None
if
(
args
.
deepspeed_config_path
is
not
None
):
strategy
=
DeepSpeed
Plugin
(
strategy
=
DeepSpeed
Strategy
(
config
=
args
.
deepspeed_config_path
,
cluster_environment
=
cluster_environment
,
)
if
(
args
.
wandb
):
if
(
args
.
wandb
and
is_rank_zero
):
wdb_logger
.
experiment
.
save
(
args
.
deepspeed_config_path
)
wdb_logger
.
experiment
.
save
(
"openfold/config.py"
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
args
.
num_nodes
>
1
:
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
strategy
=
DDPStrategy
(
find_unused_parameters
=
False
,
cluster_environment
=
cluster_environment
)
else
:
strategy
=
None
if
(
args
.
wandb
):
if
(
args
.
wandb
and
is_rank_zero
):
freeze_path
=
f
"
{
wdb_logger
.
experiment
.
dir
}
/package_versions.txt"
os
.
system
(
f
"
{
sys
.
executable
}
-m pip freeze >
{
freeze_path
}
"
)
wdb_logger
.
experiment
.
save
(
f
"
{
freeze_path
}
"
)
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
trainer
=
pl
.
Trainer
(
num_nodes
=
args
.
num_nodes
,
devices
=
args
.
gpus
,
precision
=
args
.
precision
,
max_epochs
=
args
.
max_epochs
,
default_root_dir
=
args
.
output_dir
,
strategy
=
strategy
,
callbacks
=
callbacks
,
logger
=
loggers
,
profiler
=
'simple'
,
)
if
(
args
.
resume_model_weights_only
):
...
...
@@ -621,7 +644,16 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--experiment_config_json"
,
default
=
""
,
help
=
"Path to a json file with custom config values to overwrite config setting"
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
.
add_argument
(
"--num_nodes"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--gpus"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--max_epochs"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"32"
)
parser
.
add_argument
(
"--log_every_n_steps"
,
type
=
int
,
default
=
50
)
parser
.
add_argument
(
"--accumulate_grad_batches"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--flush_logs_every_n_steps"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num_sanity_val_steps"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--mpi_plugin"
,
action
=
"store_true"
,
default
=
False
)
# parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
parser
.
set_defaults
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment