Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
cfd0fc6e
Unverified
Commit
cfd0fc6e
authored
Feb 03, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Feb 03, 2022
Browse files
Merge pull request #76 from aqlaboratory/chunking_experiment_rebased
parents
c9e0f894
2726892a
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
45 deletions
+112
-45
scripts/utils.py
scripts/utils.py
+5
-5
tests/test_evoformer.py
tests/test_evoformer.py
+9
-6
tests/test_model.py
tests/test_model.py
+1
-0
tests/test_msa.py
tests/test_msa.py
+7
-10
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+3
-3
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+2
-2
train_openfold.py
train_openfold.py
+85
-19
No files found.
scripts/utils.py
View file @
cfd0fc6e
...
...
@@ -4,19 +4,19 @@ from datetime import date
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
'uniref90_database_path'
,
type
=
str
,
'
--
uniref90_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'mgnify_database_path'
,
type
=
str
,
'
--
mgnify_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'pdb70_database_path'
,
type
=
str
,
'
--
pdb70_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'template_mmcif_dir'
,
type
=
str
,
'
--
template_mmcif_dir'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'uniclust30_database_path'
,
type
=
str
,
'
--
uniclust30_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
...
...
tests/test_evoformer.py
View file @
cfd0fc6e
...
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
assert
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
<
consts
.
eps
)
assert
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
)
<
consts
.
eps
)
class
TestExtraMSAStack
(
unittest
.
TestCase
):
...
...
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
msa_dropout
,
pair_stack_dropout
,
blocks_per_
ckpt
=
Non
e
,
ckpt
=
Fals
e
,
inf
=
inf
,
eps
=
eps
,
).
eval
()
...
...
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
msa_transition
(
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
.
cpu
()
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
print
(
out_gt
)
print
(
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_model.py
View file @
cfd0fc6e
...
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_msa.py
View file @
cfd0fc6e
...
...
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
msa_att_row
(
model
.
evoformer
.
blocks
[
0
].
msa_att_row
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
).
cpu
()
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
...
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
msa_att_col
(
model
.
evoformer
.
blocks
[
0
].
msa_att_col
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
).
cpu
()
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
...
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
extra_msa_stack
.
stack
.
blocks
[
0
]
.
msa_att_col
(
model
.
extra_msa_stack
.
blocks
[
0
].
msa_att_col
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_triangular_attention.py
View file @
cfd0fc6e
...
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_att_start
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
if
starting
else
model
.
evoformer
.
blocks
[
0
].
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size
=
None
,
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
cfd0fc6e
...
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
train_openfold.py
View file @
cfd0fc6e
...
...
@@ -13,6 +13,7 @@ import time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
import
torch
...
...
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.argparse
import
remove_arguments
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
scripts.zero_to_fp32
import
(
...
...
@@ -66,8 +67,10 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"loss"
,
loss
)
return
{
"loss"
:
loss
}
self
.
log
(
"train/loss"
,
loss
,
on_step
=
True
,
logger
=
True
)
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
# At the start of validation, load the EMA weights
...
...
@@ -78,9 +81,14 @@ class OpenFoldWrapper(pl.LightningModule):
# Calculate validation loss
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"val_loss"
,
loss
,
prog_bar
=
True
)
return
{
"val_loss"
:
loss
}
loss
=
lddt_ca
(
outputs
[
"final_atom_positions"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
eps
=
self
.
config
.
globals
.
eps
,
per_residue
=
False
,
)
self
.
log
(
"val/loss"
,
loss
,
logger
=
True
)
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
...
...
@@ -101,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
self
.
ema
.
update
(
self
.
model
)
def
on_load_checkpoint
(
self
,
checkpoint
):
self
.
ema
.
load_state_dict
(
checkpoint
[
"ema"
])
def
on_save_checkpoint
(
self
,
checkpoint
):
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
...
...
@@ -140,15 +151,15 @@ def main(args):
if
(
args
.
checkpoint_best_val
):
checkpoint_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoints"
)
mc
=
ModelCheckpoint
(
dirpath
=
checkpoint_dir
,
filename
=
"openfold_{epoch}_{step}_{val_loss:.2f}"
,
monitor
=
"val_loss"
,
monitor
=
"val/loss"
,
mode
=
"max"
,
)
callbacks
.
append
(
mc
)
if
(
args
.
early_stopping
):
es
=
EarlyStoppingVerbose
(
monitor
=
"val
_
loss"
,
monitor
=
"val
/
loss"
,
min_delta
=
args
.
min_delta
,
patience
=
args
.
patience
,
verbose
=
False
,
...
...
@@ -166,24 +177,41 @@ def main(args):
)
callbacks
.
append
(
perf
)
loggers
=
[]
if
(
args
.
wandb
):
wdb_logger
=
WandbLogger
(
name
=
args
.
experiment_name
,
save_dir
=
args
.
output_dir
,
id
=
args
.
wandb_id
,
project
=
args
.
wandb_project
,
**
{
"entity"
:
args
.
wandb_entity
}
)
loggers
.
append
(
wdb_logger
)
if
(
args
.
deepspeed_config_path
is
not
None
):
if
"SLURM_JOB_ID"
in
os
.
environ
:
cluster_environment
=
SLURMEnvironment
()
else
:
cluster_environment
=
None
#if "SLURM_JOB_ID" in os.environ:
# cluster_environment = SLURMEnvironment()
#else:
# cluster_environment = None
strategy
=
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
,
cluster_environment
=
cluster_environment
,
#
cluster_environment=cluster_environment,
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
)
>
1
or
args
.
num_nodes
>
1
:
if
(
args
.
wandb
):
wdb_logger
.
experiment
.
save
(
args
.
deepspeed_config_path
)
wdb_logger
.
experiment
.
save
(
"openfold/config.py"
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
args
.
num_nodes
>
1
:
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
else
:
strategy
=
None
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
default_root_dir
=
args
.
output_dir
,
strategy
=
strategy
,
callbacks
=
callbacks
,
logger
=
loggers
,
)
if
(
args
.
resume_model_weights_only
):
...
...
@@ -198,7 +226,7 @@ def main(args):
)
trainer
.
save_checkpoint
(
os
.
path
.
join
(
trainer
.
logger
.
log
_dir
,
"checkpoints"
,
"final.ckpt"
)
os
.
path
.
join
(
args
.
output
_dir
,
"checkpoints"
,
"final.ckpt"
)
)
...
...
@@ -318,10 +346,37 @@ if __name__ == "__main__":
"--log_performance"
,
type
=
bool_type
,
default
=
False
,
help
=
"Measure performance"
)
parser
.
add_argument
(
"--wandb"
,
action
=
"store_true"
,
default
=
False
,
)
parser
.
add_argument
(
"--experiment_name"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_id"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_project"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_entity"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to TorchScript eligible components of them model"
)
parser
.
add_argument
(
"--train_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--distillation_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
...
...
@@ -330,7 +385,15 @@ if __name__ == "__main__":
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
])
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
,
"--reload_dataloaders_every_epoch"
,
"--reload_dataloaders_every_n_epochs"
,
]
)
args
=
parser
.
parse_args
()
...
...
@@ -339,4 +402,7 @@ if __name__ == "__main__":
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
raise
ValueError
(
"For distributed training, --seed must be specified"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_n_epochs
=
1
main
(
args
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment