OpenDAS / OpenFold
Commit bbf989a7, authored Apr 16, 2024 by Jennifer
Parents: d8418293, ce000c60

psivant local pl2 upgrades without mpi configuration
Showing 5 changed files with 63 additions and 29 deletions.
Files changed:
  environment.yml                     +4 -3
  openfold/data/data_modules.py       +4 -4
  openfold/data/data_pipeline.py      +3 -3
  openfold/utils/superimposition.py   +2 -2
  train_openfold.py                   +50 -17
environment.yml

```diff
-name: openfold-venv
+name: pytorch1-plupgrade
 channels:
   - conda-forge
   - bioconda
@@ -11,7 +11,7 @@ dependencies:
   - openmm=7.7
   - pdbfixer
   - cudatoolkit==11.3.*
-  - pytorch-lightning==1.5.10
+  - pytorch-lightning==2.0.9
   - biopython==1.79
   - numpy==1.21
   - pandas==2.0
@@ -19,11 +19,12 @@ dependencies:
   - requests
   - scipy==1.7
   - tqdm==4.62.2
-  - typing-extensions==3.10
+  - typing-extensions==4.0
   - wandb==0.12.21
   - modelcif==0.7
   - awscli
   - ml-collections
+  - mkl==2024.0.0
   - aria2
   - git
   - bioconda::hmmer==3.3.2
```
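The headline change is the jump from pytorch-lightning 1.5.10 to 2.0.9 (with typing-extensions bumped to a 4.x line to match); every code change below follows from that upgrade. Since a stale environment fails in confusing ways under the new API, a fail-fast guard like the following can help. This is a hedged sketch, not part of the commit:

```python
# Sketch: abort early if the active environment still has Lightning 1.x,
# since the rest of this commit relies on 2.x-only APIs.
import pytorch_lightning as pl

major = int(pl.__version__.split(".")[0])
if major < 2:
    raise RuntimeError(
        f"pytorch-lightning>=2.0 required, found {pl.__version__}; "
        "recreate the env from environment.yml"
    )
```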
openfold/data/data_modules.py

```diff
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         with open(distillation_alignment_index_path, "r") as fp:
             self.distillation_alignment_index = json.load(fp)
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
                               template_mmcif_dir=self.template_mmcif_dir,
```
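The new signature matches the Lightning 2.x `LightningDataModule` contract: the trainer now invokes `setup(stage=...)` with `"fit"`, `"validate"`, `"test"`, or `"predict"`, so overrides must accept the argument, and the `None` default keeps manual `dm.setup()` calls working. A minimal sketch of the 2.x-style hooks, using a toy dataset rather than OpenFold's own classes; it also previews the `return []` change to `val_dataloader` shown further down:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    """Illustrative datamodule, not OpenFold code."""

    def setup(self, stage=None):
        # Lightning 2.x passes the current stage; building everything
        # when stage is None preserves the old manual-call behavior.
        data = TensorDataset(torch.randn(8, 3), torch.zeros(8))
        if stage in (None, "fit"):
            self.train_set = data
        if stage in (None, "fit", "validate"):
            self.val_set = data

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=4)

    def val_dataloader(self):
        # Under 2.x, an empty list is a safe "no validation" signal.
        if getattr(self, "val_set", None) is None:
            return []
        return DataLoader(self.val_set, batch_size=4)
```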
```diff
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )
 
-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()
```
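The context lines show the seeding pattern this method relies on: when `batch_seed` is set, a dedicated `torch.Generator` drives the dataloader so the shuffle order is reproducible run to run. A standalone sketch of that pattern; the variable names here are illustrative, only `torch.Generator` and the `DataLoader` `generator` argument come from the real code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_seed = 42  # stand-in for self.batch_seed
dataset = TensorDataset(torch.arange(10).float())

generator = None
if batch_seed is not None:
    generator = torch.Generator()
    generator.manual_seed(batch_seed)

# Two runs with the same seed visit the samples in the same order.
loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)
print([batch[0].tolist() for batch in loader])
```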
```diff
@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        return []
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")
```
```diff
@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
 
-    def setup(self):
+    def setup(self, setup=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
                               template_mmcif_dir=self.template_mmcif_dir,
```
openfold/data/data_pipeline.py

```diff
@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
     features["num_alignments"] = np.array(
         [num_alignments] * num_res, dtype=np.int32
     )
-    features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
+    features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
     return features
```

```diff
@@ -590,7 +590,7 @@ def convert_monomer_features(
 ) -> FeatureDict:
     """Reshapes and modifies monomer features for multimer models."""
     converted = {}
-    converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
+    converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
     unnecessary_leading_dim_feats = {
         'sequence', 'domain_name', 'num_alignments', 'seq_length'
     }
```

```diff
@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
         )
         mmcif_feats["release_date"] = np.array(
-            [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
+            [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
         )
         mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
```
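All three hunks in this file make the same substitution: `dtype=np.object_` becomes the builtin `dtype=object`. The two spellings produce identical object-dtype arrays, so this is a forward-compatibility cleanup that drops a dependence on NumPy's alias names (the related bare `np.object` alias was deprecated in NumPy 1.20 and later removed). A quick demonstration with illustrative values:

```python
import numpy as np

species_ids = [b"homo_sapiens", b"mus_musculus"]  # illustrative values

arr = np.array(species_ids, dtype=object)
print(arr.dtype)   # object
print(arr[0])      # b'homo_sapiens'

# np.dtype normalizes both spellings to the same dtype.
assert np.dtype(object) == np.dtype("O")
```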
openfold/utils/superimposition.py

```diff
@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
 def _superimpose_single(reference, coords):
-    reference_np = reference.detach().cpu().numpy()
-    coords_np = coords.detach().cpu().numpy()
+    reference_np = reference.detach().to(torch.float).cpu().numpy()
+    coords_np = coords.detach().to(torch.float).cpu().numpy()
     superimposed, rmsd = _superimpose_np(reference_np, coords_np)
     return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
```
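The up-cast to `torch.float` before `.numpy()` matters under mixed precision: NumPy has no bfloat16 dtype, so converting a bf16 tensor directly raises a `TypeError`. Casting first makes superimposition work regardless of the precision the trainer runs in. A small sketch of the failure and the fix:

```python
import torch

x = torch.randn(4, 3, dtype=torch.bfloat16)

try:
    x.detach().cpu().numpy()  # fails: NumPy has no bfloat16
except TypeError as exc:
    print(f"direct conversion failed: {exc}")

# Up-cast to float32 first, as the patch does.
x_np = x.detach().to(torch.float).cpu().numpy()
print(x_np.dtype)  # float32
```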
train_openfold.py

```diff
@@ -9,8 +9,11 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks import DeviceStatsMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
+from pytorch_lightning.plugins.environments import MPIEnvironment
+from pytorch_lightning import seed_everything
 import torch
 import wandb
 
 from openfold.config import model_config
 from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
```
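Lightning 2.x removed `pytorch_lightning.plugins.training_type`; distributed backends now live in `pytorch_lightning.strategies`, while cluster environments such as `MPIEnvironment` remain under `pytorch_lightning.plugins.environments` (note that `MPIEnvironment` needs `mpi4py` at construction time). A sketch of how the renamed pieces fit together; `use_mpi` is an illustrative stand-in for the script's `args.mpi_plugin` flag:

```python
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment

use_mpi = False  # stand-in for args.mpi_plugin

# The strategy owns the distributed backend; the optional cluster
# environment tells it how ranks are discovered (here, an MPI launcher).
cluster_environment = MPIEnvironment() if use_mpi else None
strategy = DDPStrategy(
    find_unused_parameters=False,
    cluster_environment=cluster_environment,
)
```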
```diff
@@ -25,7 +28,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
-from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.utils.validation_metrics import (
```
```diff
@@ -60,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.cached_weights = None
         self.last_lr_step = -1
-        self.save_hyperparameters
+        self.save_hyperparameters()
 
     def forward(self, batch):
         return self.model(batch)
```
```diff
@@ -71,14 +73,17 @@ class OpenFoldWrapper(pl.LightningModule):
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
-                on_step=train, on_epoch=(not train), logger=True,
+                prog_bar=(loss_name == 'loss'),
+                # on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
+                on_step=train, on_epoch=(not train), logger=True, sync_dist=True,
             )
 
             if (train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,
-                    on_step=False, on_epoch=True, logger=True,
+                    # on_step=False, on_epoch=True, logger=True, sync_dist=False,
+                    on_step=False, on_epoch=True, logger=True, sync_dist=True,
                 )
 
         with torch.no_grad():
```
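The recurring edit in `_log` adds `sync_dist=True` (the previous defaults are kept as commented-out lines in the commit): with it, `self.log` all-reduces the value across ranks before recording, so multi-GPU runs report a global average instead of rank 0's local value, at the cost of one collective per logged metric. A minimal sketch of the call pattern in an illustrative module, not OpenFold's wrapper:

```python
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    """Illustrative module showing the logging pattern from the patch."""

    def training_step(self, batch, batch_idx):
        loss = batch[0].float().mean()  # stand-in for the real loss
        self.log(
            "train/loss",
            loss,
            on_step=True, on_epoch=False, logger=True,
            prog_bar=True,
            sync_dist=True,  # all-reduce across ranks before logging
        )
        return loss
```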
```diff
@@ -92,7 +97,9 @@ class OpenFoldWrapper(pl.LightningModule):
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
-                on_step=False, on_epoch=True, logger=True
+                prog_bar=(k == 'loss'),
+                # on_step=False, on_epoch=True, logger=True, sync_dist=False,
+                on_step=False, on_epoch=True, logger=True, sync_dist=True,
             )
 
     def training_step(self, batch, batch_idx):
```
```diff
@@ -155,7 +162,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self._log(loss_breakdown, batch, outputs, train=False)
 
-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
```
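Lightning 2.0 dropped the `validation_epoch_end(self, outputs)` hook in favor of `on_validation_epoch_end(self)`, which receives no outputs list; anything needed at epoch end (here, the cached weights being restored) must already live on the module. A sketch of the 2.x pattern with manual accumulation; the class and metric names are illustrative:

```python
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self._val_losses = []  # 2.x no longer hands you an outputs list

    def validation_step(self, batch, batch_idx):
        loss = batch[0].float().mean()
        self._val_losses.append(loss)

    def on_validation_epoch_end(self):
        # Replaces validation_epoch_end(self, outputs) from Lightning 1.x.
        if self._val_losses:
            self.log("val/loss_mean", torch.stack(self._val_losses).mean())
        self._val_losses.clear()
```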
```diff
@@ -378,41 +385,58 @@ def main(args):
         callbacks.append(lr_monitor)
 
     loggers = []
+    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
     if (args.wandb):
+        if args.mpi_plugin and is_rank_zero:
+            wandb_init_dict = dict(
+                name=args.experiment_name,
+                project=args.wandb_project,
+                id=args.wandb_id,
+                dir=args.output_dir,
+                resume="allow",
+                anonymous=None,
+                entity=args.wandb_entity,
+            )
+            wandb.run = wandb.init(**wandb_init_dict)
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
             id=args.wandb_id,
             project=args.wandb_project,
             config=config.to_dict(),
             **{"entity": args.wandb_entity}
         )
         loggers.append(wdb_logger)
 
+    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
     if (args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
+            cluster_environment=cluster_environment,
         )
-        if (args.wandb):
+        if (args.wandb and is_rank_zero):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False,
+                               cluster_environment=cluster_environment)
     else:
         strategy = None
 
-    if (args.wandb):
+    if (args.wandb and is_rank_zero):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
+    trainer = pl.Trainer(
+        num_nodes=args.num_nodes,
+        devices=args.gpus,
+        precision=args.precision,
+        max_epochs=args.max_epochs,
         default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
         logger=loggers,
+        profiler='simple',
     )
 
     if (args.resume_model_weights_only):
```
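`Trainer.from_argparse_args` is gone in Lightning 2.x (as is `add_argparse_args`, removed in the next hunk), so the commit builds the `Trainer` explicitly from the script's own flags and threads the MPI-aware strategy through. A condensed sketch of that wiring, assuming an `args` namespace carrying the flags registered at the bottom of the file:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy


def build_trainer(args, callbacks, loggers):
    # Explicit construction replaces pl.Trainer.from_argparse_args(args, ...).
    if (args.gpus or 0) > 1 or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False)
    else:
        strategy = "auto"  # Lightning 2.x default when nothing special is needed
    return pl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.gpus if args.gpus is not None else "auto",
        precision=args.precision,
        max_epochs=args.max_epochs,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )
```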
```diff
@@ -623,8 +647,17 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json", default="",
         help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser = pl.Trainer.add_argparse_args(parser)
+    parser.add_argument("--num_nodes", type=int, default=1)
+    parser.add_argument("--gpus", type=int, default=None)
+    parser.add_argument("--max_epochs", type=int, default=None)
+    parser.add_argument("--precision", type=str, default="32")
+    parser.add_argument("--log_every_n_steps", type=int, default=50)
+    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
+    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
+    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
+    parser.add_argument("--mpi_plugin", action="store_true", default=False)
+    # parser = pl.Trainer.add_argparse_args(parser)
 
     # Disable the initial validation pass
     parser.set_defaults(
         num_sanity_val_steps=0,
```
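With `add_argparse_args` gone, the script owns exactly the Trainer flags registered above, so invocations must use these names rather than whatever the installed Lightning version happened to expose. A small self-contained check of the new argument surface; the flag subset and values are illustrative, and under Lightning 2.x `--precision` accepts strings such as "32", "16-mixed", or "bf16-mixed":

```python
import argparse

parser = argparse.ArgumentParser()
# Subset of the flags the script now registers itself.
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--gpus", type=int, default=None)
parser.add_argument("--max_epochs", type=int, default=None)
parser.add_argument("--precision", type=str, default="32")
parser.add_argument("--mpi_plugin", action="store_true", default=False)

# e.g. python train_openfold.py <data args...> --gpus 4 --precision bf16-mixed
args = parser.parse_args(["--gpus", "4", "--precision", "bf16-mixed"])
print(args.gpus, args.precision, args.mpi_plugin)
```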