OpenDAS / OpenFold

Commit df4dfacb, authored Jan 24, 2024 by Jennifer (parent e813bb53)

    first pass changes to run with pl 2.1

Showing 13 changed files with 184 additions and 137 deletions (+184 / -137).
Changed files:

    openfold/data/data_modules.py                     +5    -4
    openfold/utils/seed.py                            +1    -1
    tests/compare_utils.py                            +18   -0
    tests/test_deepspeed_evo_attention.py             +2    -4
    tests/test_evoformer.py                           +5    -6
    tests/test_feats.py                               +1    -1
    tests/test_msa.py                                 +3    -3
    tests/test_outer_product_mean.py                  +1    -1
    tests/test_structure_module.py                    +2    -2
    tests/test_template.py                            +2    -4
    tests/test_triangular_attention.py                +1    -1
    tests/test_triangular_multiplicative_update.py    +1    -1
    train_openfold.py                                 +142  -109
openfold/data/data_modules.py:

@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         with open(distillation_alignment_index_path, "r") as fp:
             self.distillation_alignment_index = json.load(fp)
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,

@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )
 
-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()

@@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        # Temp fix to pass the validation step
+        return []
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")

@@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
 
-    def setup(self):
+    def setup(self, setup=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
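A note on the signature changes above: in PyTorch Lightning 2.x the Trainer calls LightningDataModule.setup(stage) with a stage string ("fit", "validate", "test", or "predict"), so the hook must accept that argument, and returning an empty list from val_dataloader lets the validation loop run as a harmless no-op when no eval dataset is configured. The multimer override names its new parameter `setup`, which looks like a typo for `stage`. A minimal sketch of the expected hook shapes, using placeholder datasets rather than OpenFold's:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    # Hypothetical module illustrating the PL 2.x hook signatures; not OpenFold code.
    def setup(self, stage=None):
        # `stage` is supplied by the Trainer; the default lets manual setup() calls still work.
        self.train_set = TensorDataset(torch.randn(8, 3))
        self.eval_set = None

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=2)

    def val_dataloader(self):
        if self.eval_set is not None:
            return DataLoader(self.eval_set, batch_size=2)
        return []  # no eval data: validation becomes a no-op instead of crashing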
openfold/utils/seed.py:

@@ -2,7 +2,7 @@ import os
 import logging
 import random
 
 import numpy as np
-from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning import seed_everything
 
 from openfold.utils.suppress_output import SuppressLogging
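For context, seed_everything is exported from the package root in PyTorch Lightning 2.x; the old pytorch_lightning.utilities.seed import path no longer exposes it. A minimal usage sketch:

from pytorch_lightning import seed_everything

# Seeds Python's random module, NumPy, and torch; workers=True also seeds dataloader workers.
seed_everything(42, workers=True)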
tests/compare_utils.py:

@@ -6,6 +6,7 @@ import sys
 import unittest
 
 import numpy as np
+import torch
 
 from openfold.config import model_config
 from openfold.model.model import AlphaFold

@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
             "Make sure to call import_alphafold before running this function"
         )
     return params
+
+
+def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
+    # Helper function for comparing absolute differences of two torch tensors.
+    abs_diff = torch.abs(expected - actual)
+    err = compare_func(abs_diff)
+    zero_tensor = torch.tensor(0, dtype=err.dtype)
+    rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
+    torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
+
+
+def assert_max_abs_diff_small(expected, actual, eps):
+    _assert_abs_diff_small_base(torch.max, expected, actual, eps)
+
+
+def assert_mean_abs_diff_small(expected, actual, eps):
+    _assert_abs_diff_small_base(torch.mean, expected, actual, eps)
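These helpers replace the ad-hoc `self.assertTrue(torch.max(torch.abs(a - b)) < eps)` checks in the test files below; torch.testing.assert_close reports the offending values on failure instead of a bare boolean. A rough usage sketch (tensor shapes and the import path are illustrative, not taken from the repository):

import torch
import compare_utils  # assumed importable when tests are run from the tests directory

expected = torch.randn(4, 8)
actual = expected + 1e-7 * torch.randn(4, 8)

# Both raise a descriptive AssertionError if the absolute difference exceeds eps.
compare_utils.assert_max_abs_diff_small(expected, actual, eps=1e-4)
compare_utils.assert_mean_abs_diff_small(expected, actual, eps=1e-4)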
tests/test_deepspeed_evo_attention.py:

@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         )
         out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
 
-        err = torch.max(torch.abs(out_repro - out_repro_ds))
-        self.assertTrue(err < eps, f'Error {err}')
+        compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
 
     def test_compare_model(self):
         """

@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
         out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
 
-        err = torch.mean(torch.abs(out_repro - out_repro_ds))
-        self.assertTrue(err < eps, f'Error: {err}')
+        compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
 
 
 if __name__ == "__main__":
tests/test_evoformer.py:

@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
-        self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
 
         # Inplace version
         out_repro_msa, out_repro_pair = model.evoformer.blocks[0](

@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
-        self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
 
 
 class TestExtraMSAStack(unittest.TestCase):

@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
             .cpu()
         )
 
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 if __name__ == "__main__":
     unittest.main()
tests/test_feats.py:

@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
             torch.tensor(restype_atom14_rigid_group_positions).cuda(),
         ).cpu()
 
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 if __name__ == "__main__":
tests/test_msa.py:

@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
             )
         ).cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 class TestMSAColumnAttention(unittest.TestCase):

@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
             )
         ).cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 class TestMSAColumnGlobalAttention(unittest.TestCase):

@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
             .cpu()
         )
 
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 if __name__ == "__main__":
tests/test_outer_product_mean.py:

@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
         # Even when correct, OPM has large, precision-related errors. It gets
         # a special pass from consts.eps.
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
 
 
 if __name__ == "__main__":
tests/test_structure_module.py:

@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
         # The structure module, thanks to angle normalization, is very volatile
         # We only assess the mean here. Heuristically speaking, it seems to
         # have lower error in general on real rather than synthetic data.
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
 
 
 class TestInvariantPointAttention(unittest.TestCase):

@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
             torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
         ).cpu()
 
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 class TestAngleResnet(unittest.TestCase):
tests/test_template.py:

@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
             _mask_trans=False,
         ).cpu()
 
-        diff = torch.max(torch.abs(out_gt - out_repro))
-        self.assertTrue(diff < consts.eps,
-                        msg=f"Found difference between ground truth and reproduction of {diff}")
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 class Template(unittest.TestCase):

@@ -286,7 +284,7 @@ class Template(unittest.TestCase):
         out_repro = out_repro_all["template_pair_embedding"]
         out_repro = out_repro.cpu()
 
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
 
 
 if __name__ == "__main__":
tests/test_triangular_attention.py:

@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
             chunk_size=None,
         ).cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
 
     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
tests/test_triangular_multiplicative_update.py:

@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
             inplace_safe=True,
             _inplace_chunk_size=4,
         ).cpu()
 
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
 
     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_mul_out_compare(self):
train_openfold.py:

@@ -7,7 +7,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
 import torch
 
 from openfold.config import model_config
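The old pytorch_lightning.plugins.training_type module was removed in PyTorch Lightning 2.x; distributed back ends are now Strategy objects passed to Trainer(strategy=...), which is what this import swap prepares for. A minimal sketch of the replacement pattern (device counts and the config path are placeholders):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

# Replacement for the old DDPPlugin(find_unused_parameters=False).
ddp = DDPStrategy(find_unused_parameters=False)

# Replacement for DeepSpeedPlugin(config=...); the new class accepts the same JSON config path.
# deepspeed = DeepSpeedStrategy(config="deepspeed_config.json")

trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=ddp)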
@@ -55,7 +55,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
 
         self.cached_weights = None
         self.last_lr_step = -1

@@ -66,12 +66,12 @@ class OpenFoldWrapper(pl.LightningModule):
         phase = "train" if train else "val"
         for loss_name, indiv_loss in loss_breakdown.items():
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
                 on_step=train, on_epoch=(not train), logger=True,
             )
 
             if (train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,

@@ -80,12 +80,12 @@ class OpenFoldWrapper(pl.LightningModule):
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
                 batch,
                 outputs,
                 superimposition_metrics=(not train)
             )
 
         for k, v in other_metrics.items():
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),

@@ -93,7 +93,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
     def training_step(self, batch, batch_idx):
         if (self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
 
         ground_truth = batch.pop('gt_features', None)

@@ -124,12 +124,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if (self.cached_weights is None):
             # model.state_dict() contains references to model weights rather
             # than copies. Therefore, we need to clone them before calling
             # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
-            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            def clone_param(t):
+                return t.detach().clone()
+
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
 
             self.model.load_state_dict(self.ema.state_dict()["params"])
 
         ground_truth = batch.pop('gt_features', None)
@@ -151,23 +152,23 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
         self._log(loss_breakdown, batch, outputs, train=False)
 
-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
 
     def _compute_validation_metrics(
         self,
         batch,
         outputs,
         superimposition_metrics=False
     ):
         metrics = {}
 
         gt_coords = batch["all_atom_positions"]
         pred_coords = outputs["final_atom_positions"]
         all_atom_mask = batch["all_atom_mask"]
 
         # This is super janky for superimposition. Fix later
         gt_coords_masked = gt_coords * all_atom_mask[..., None]
         pred_coords_masked = pred_coords * all_atom_mask[..., None]
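PyTorch Lightning 2.0 dropped validation_epoch_end(self, outputs); epoch-level logic now lives in on_validation_epoch_end(self), which the Trainer calls with no extra arguments, so the `_` parameter kept in the hunk above would likely raise a TypeError when the hook fires. A minimal sketch of the 2.x pattern, using a toy module rather than OpenFold's wrapper:

import torch
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    # Hypothetical module illustrating the renamed hook; not OpenFold code.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.val_losses = []

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        self.val_losses.append(loss.detach())
        return loss

    def on_validation_epoch_end(self):
        # No `outputs` argument in PL 2.x; aggregate state collected in validation_step instead.
        if self.val_losses:
            self.log("val/loss_epoch", torch.stack(self.val_losses).mean())
        self.val_losses.clear()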
@@ -175,7 +176,7 @@ class OpenFoldWrapper(pl.LightningModule):
         gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
         pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
         all_atom_mask_ca = all_atom_mask[..., ca_pos]
 
         lddt_ca_score = lddt_ca(
             pred_coords,
             gt_coords,

@@ -183,18 +184,18 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
 
         metrics["lddt_ca"] = lddt_ca_score
 
         drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
             mask=all_atom_mask_ca,  # still required here to compute n
         )
 
         metrics["drmsd_ca"] = drmsd_ca_score
 
         if (superimposition_metrics):
             superimposed_pred, alignment_rmsd = superimpose(
                 gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
@@ -208,22 +209,22 @@ class OpenFoldWrapper(pl.LightningModule):
             metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score
 
         return metrics
 
     def configure_optimizers(
         self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # return torch.optim.Adam(
         #     self.model.parameters(),
         #     lr=learning_rate,
         #     eps=eps
         # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
@@ -247,8 +248,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_load_checkpoint(self, checkpoint):
         ema = checkpoint["ema"]
         if (not self.model.template_config.enabled):
             ema["params"] = {k: v for k, v in ema["params"].items() if not "template" in k}
 
         self.ema.load_state_dict(ema)
 
     def on_save_checkpoint(self, checkpoint):
@@ -259,69 +261,72 @@ class OpenFoldWrapper(pl.LightningModule):
     def load_from_jax(self, jax_path):
         model_basename = os.path.splitext(
             os.path.basename(os.path.normpath(jax_path))
         )[0]
         model_version = "_".join(model_basename.split("_")[1:])
         import_jax_weights_(
             self.model, jax_path, version=model_version
         )
 
 
 def main(args):
     if (args.seed is not None):
         seed_everything(args.seed)
 
     config = model_config(
         args.config_preset,
         train=True,
         low_prec=(str(args.precision) == "16")
     )
     model_module = OpenFoldWrapper(config)
 
     if (args.resume_from_ckpt):
         if (os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
         else:
             sd = torch.load(args.resume_from_ckpt)
             last_global_step = int(sd['global_step'])
         model_module.resume_last_lr_step(last_global_step)
         logging.info("Successfully loaded last lr step...")
 
     if (args.resume_from_ckpt and args.resume_model_weights_only):
         if (os.path.isdir(args.resume_from_ckpt)):
             sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
         else:
             sd = torch.load(args.resume_from_ckpt)
         sd = {k[len("module."):]: v for k, v in sd.items()}
         import_openfold_weights_(model=model_module, state_dict=sd)
         logging.info("Successfully loaded model weights...")
 
     if (args.resume_from_jax_params):
         model_module.load_from_jax(args.resume_from_jax_params)
         logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
 
     # TorchScript components of the model
     if (args.script_modules):
         script_preset_(model_module)
 
     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     else:
         data_module = OpenFoldDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
 
     data_module.prepare_data()
     data_module.setup()
 
     callbacks = []
     if (args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
@@ -329,7 +334,7 @@ def main(args):
         )
         callbacks.append(mc)
 
     if (args.early_stopping):
         es = EarlyStoppingVerbose(
             monitor="val/lddt_ca",
             min_delta=args.min_delta,

@@ -341,7 +346,7 @@ def main(args):
         )
         callbacks.append(es)
 
     if (args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
             log_file=os.path.join(args.output_dir, "performance_log.json"),

@@ -349,12 +354,12 @@ def main(args):
         )
         callbacks.append(perf)
 
     if (args.log_lr):
         lr_monitor = LearningRateMonitor(logging_interval="step")
         callbacks.append(lr_monitor)
 
     loggers = []
     if (args.wandb):
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
@@ -364,38 +369,43 @@ def main(args):
         )
         loggers.append(wdb_logger)
 
     if (args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
         )
         if (args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif ((args.gpus is not None and args.gpus > 1) or args.num_nodes > 1):
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False)
     else:
         strategy = None
 
     if (args.wandb):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-    )
+    # Raw dump of all args from pl.Trainer constructor
+    trainer_kws = set([
+        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision',
+        'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs',
+        'max_steps', 'min_steps', 'max_tim', 'limit_train_batches',
+        'limit_val_batches', 'limit_test_batches', 'limit_predict_batches',
+        'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch',
+        'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing',
+        'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches',
+        'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic',
+        'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler',
+        'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm',
+        'reload_dataloaders_every_n_epochs', 'default_root_dir',
+    ])
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)
 
     if (args.resume_model_weights_only):
         ckpt_path = None
     else:
         ckpt_path = args.resume_from_ckpt
 
     trainer.fit(
         model_module,
         datamodule=data_module,
         ckpt_path=ckpt_path,
     )
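Trainer.from_argparse_args (and add_argparse_args) no longer exist in PyTorch Lightning 2.x, hence the hand-maintained trainer_kws filter and the direct pl.Trainer(**trainer_args) call above; note the dumped list contains 'max_tim', which looks like a truncated 'max_time'. A minimal sketch of the same filtering pattern with a hypothetical, much smaller option set:

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
# Hypothetical subset of Trainer options exposed on this CLI.
parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--log_every_n_steps", type=int, default=25)
parser.add_argument("--output_dir", type=str, default=".")
args = parser.parse_args([])  # use defaults so the sketch runs standalone

# Keep only the keys that pl.Trainer actually accepts, then construct it directly.
trainer_kws = {"max_epochs", "log_every_n_steps"}
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
trainer = pl.Trainer(default_root_dir=args.output_dir, **trainer_args)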
@@ -594,36 +604,59 @@ if __name__ == "__main__":
         "--distillation_alignment_index_path", type=str, default=None,
         help="Distillation alignment index. See the README for instructions."
     )
-    parser = pl.Trainer.add_argparse_args(parser)
-
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
+    parser.add_argument(
+        "--num_nodes", type=int, default=1,
+    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+    )
+    parser.add_argument(
+        "--precision", type=str, default=None,
+    )
+    parser.add_argument(
+        "--replace_sampler_ddp", type=bool_type, default=True,
+    )
+    parser.add_argument(
+        "--max_epochs", type=int, default=1,
+    )
+    parser.add_argument(
+        "--log_every_n_steps", type=int, default=25,
+    )
+    parser.add_argument(
+        "--num_sanity_val_steps", type=int, default=0,
+    )
+    # parser = pl.Trainer.add_argparse_args(parser)
+    #
+    # # Disable the initial validation pass
+    # parser.set_defaults(
+    #     num_sanity_val_steps=0,
+    # )
 
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    # # Remove some buggy/redundant arguments introduced by the Trainer
+    # remove_arguments(
+    #     parser,
+    #     [
+    #         "--accelerator",
+    #         "--resume_from_checkpoint",
+    #         "--reload_dataloaders_every_epoch",
+    #         "--reload_dataloaders_every_n_epochs",
+    #     ]
+    # )
 
     args = parser.parse_args()
 
     if (args.seed is None and
         ((args.gpus is not None and args.gpus > 1) or
          (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
 
     if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
         raise ValueError("DeepSpeed and FP16 training are not compatible")
 
     if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError(
             "Choose between loading pretrained Jax-weights and a checkpoint-path"
         )
 
     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1