OpenDAS / OpenFold

Commit 50f72d45 · authored Feb 10, 2022 by Gustaf Ahdritz

Improve logging

Parent: a3e8ebbc

Showing 5 changed files with 270 additions and 42 deletions (+270 −42)
openfold/data/data_modules.py        +25 −25
openfold/utils/loss.py               +9  −2
openfold/utils/superimposition.py    +77 −0
openfold/utils/validation_metrics.py +38 −0
train_openfold.py                    +121 −15
openfold/data/data_modules.py — view file @ 50f72d45

@@ -213,16 +213,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
 def deterministic_train_filter(
-    prot_data_cache_entry: Any,
+    chain_data_cache_entry: Any,
     max_resolution: float = 9.,
     max_single_aa_prop: float = 0.8,
 ) -> bool:
     # Hard filters
-    resolution = prot_data_cache_entry.get("resolution", None)
+    resolution = chain_data_cache_entry.get("resolution", None)
     if(resolution is not None and resolution > max_resolution):
         return False
 
-    seq = prot_data_cache_entry["seq"]
+    seq = chain_data_cache_entry["seq"]
     counts = {}
     for aa in seq:
         counts.setdefault(aa, 0)
@@ -236,16 +236,16 @@ def deterministic_train_filter(
 def get_stochastic_train_filter_prob(
-    prot_data_cache_entry: Any,
+    chain_data_cache_entry: Any,
 ) -> List[float]:
     # Stochastic filters
     probabilities = []
 
-    cluster_size = prot_data_cache_entry.get("cluster_size", None)
+    cluster_size = chain_data_cache_entry.get("cluster_size", None)
     if(cluster_size is not None and cluster_size > 0):
         probabilities.append(1 / cluster_size)
 
-    chain_length = len(prot_data_cache_entry["seq"])
+    chain_length = len(chain_data_cache_entry["seq"])
     probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
 
     # Risk of underflow here?
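For reference, the length-based term above clamps the chain length to [256, 512] before scaling by 1/512, so every chain contributes a sampling weight between 0.5 and 1.0. A quick check of the arithmetic (plain Python, lengths chosen only for illustration):

    for chain_length in (100, 300, 1000):
        weight = (1 / 512) * (max(min(chain_length, 512), 256))
        print(chain_length, weight)
    # 100  -> 0.5      (clamped up to 256)
    # 300  -> 0.5859375
    # 1000 -> 1.0      (clamped down to 512)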
@@ -267,7 +267,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         datasets: Sequence[OpenFoldSingleDataset],
         probabilities: Sequence[int],
         epoch_len: int,
-        prot_data_cache_paths: List[str],
+        chain_data_cache_paths: List[str],
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
     ):
@@ -276,10 +276,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.epoch_len = epoch_len
         self.generator = generator
 
-        self.prot_data_caches = []
-        for path in prot_data_cache_paths:
+        self.chain_data_caches = []
+        for path in chain_data_cache_paths:
             with open(path, "r") as fp:
-                self.prot_data_caches.append(json.load(fp))
+                self.chain_data_caches.append(json.load(fp))
 
         def looped_shuffled_dataset_idx(dataset_len):
             while True:
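The caches loaded here are plain JSON files keyed by chain ID. The only keys the filters in this diff actually read are "resolution", "seq", and "cluster_size"; a minimal sketch of a compatible file (field values invented for illustration, and real cache files may carry additional fields):

    import json

    chain_data_cache = {
        "1abc_A": {"resolution": 2.1, "seq": "MKTAYIAKQR", "cluster_size": 4},
        "2xyz_B": {"resolution": 9.5, "seq": "GGGGGGGGGG", "cluster_size": 1},
    }
    with open("chain_data_cache.json", "w") as fp:
        json.dump(chain_data_cache, fp)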
@@ -298,19 +298,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
             max_cache_len = int(epoch_len * probabilities[dataset_idx])
             dataset = self.datasets[dataset_idx]
             idx_iter = looped_shuffled_dataset_idx(len(dataset))
-            prot_data_cache = self.prot_data_caches[dataset_idx]
+            chain_data_cache = self.chain_data_caches[dataset_idx]
 
             while True:
                 weights = []
                 idx = []
                 for _ in range(max_cache_len):
                     candidate_idx = next(idx_iter)
                     chain_id = dataset.idx_to_chain_id(candidate_idx)
-                    prot_data_cache_entry = prot_data_cache[chain_id]
-                    if(not deterministic_train_filter(prot_data_cache_entry)):
+                    chain_data_cache_entry = chain_data_cache[chain_id]
+                    if(not deterministic_train_filter(chain_data_cache_entry)):
                         continue
 
                     p = get_stochastic_train_filter_prob(
-                        prot_data_cache_entry,
+                        chain_data_cache_entry,
                     )
                     weights.append([1. - p, p])
                     idx.append(candidate_idx)
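This loop only builds per-candidate [reject, accept] weight pairs; how they are consumed falls outside the hunk. One plausible way to turn each pair into a stochastic keep/drop decision, sketched with torch (values invented, purely an illustration of the idea rather than the repository's actual sampling code):

    import torch

    weights = torch.tensor([[0.7, 0.3], [0.1, 0.9], [0.5, 0.5]])  # rows of [1 - p, p]
    g = torch.Generator()
    g.manual_seed(0)
    draws = torch.multinomial(weights, 1, generator=g).squeeze(-1)  # 1 means "keep"
    kept = [i for i, d in enumerate(draws.tolist()) if d == 1]
    print(kept)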
@@ -471,10 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
-        train_prot_data_cache_path: Optional[str] = None,
+        train_chain_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
-        distillation_prot_data_cache_path: Optional[str] = None,
+        distillation_chain_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,
@@ -496,11 +496,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
-        self.train_prot_data_cache_path = train_prot_data_cache_path
+        self.train_chain_data_cache_path = train_chain_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_prot_data_cache_path = (
-            distillation_prot_data_cache_path
+        self.distillation_chain_data_cache_path = (
+            distillation_chain_data_cache_path
         )
         self.val_data_dir = val_data_dir
         self.val_alignment_dir = val_alignment_dir
@@ -589,22 +589,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
             datasets = [train_dataset, distillation_dataset]
             d_prob = self.config.train.distillation_prob
             probabilities = [1 - d_prob, d_prob]
-            prot_data_cache_paths = [
-                self.train_prot_data_cache_path,
-                self.distillation_prot_data_cache_path,
+            chain_data_cache_paths = [
+                self.train_chain_data_cache_path,
+                self.distillation_chain_data_cache_path,
             ]
         else:
             datasets = [train_dataset]
             probabilities = [1.]
-            prot_data_cache_paths = [
-                self.train_prot_data_cache_path,
+            chain_data_cache_paths = [
+                self.train_chain_data_cache_path,
             ]
 
         self.train_dataset = OpenFoldDataset(
             datasets=datasets,
             probabilities=probabilities,
             epoch_len=self.train_epoch_len,
-            prot_data_cache_paths=prot_data_cache_paths,
+            chain_data_cache_paths=chain_data_cache_paths,
             _roll_at_init=False,
         )
openfold/utils/loss.py — view file @ 50f72d45

@@ -1552,7 +1552,7 @@ class AlphaFoldLoss(nn.Module):
         super(AlphaFoldLoss, self).__init__()
         self.config = config
 
-    def forward(self, out, batch):
+    def forward(self, out, batch, _return_breakdown=False):
         if "violation" not in out.keys():
             out["violation"] = find_structural_violations(
                 batch,
@@ -1609,6 +1609,7 @@ class AlphaFoldLoss(nn.Module):
         )
 
         cum_loss = 0.
+        losses = {}
         for loss_name, loss_fn in loss_fns.items():
             weight = self.config[loss_name].weight
             loss = loss_fn()
@@ -1616,6 +1617,9 @@ class AlphaFoldLoss(nn.Module):
                 logging.warning(f"{loss_name} loss is NaN. Skipping...")
                 loss = loss.new_tensor(0., requires_grad=True)
             cum_loss = cum_loss + weight * loss
+            losses[loss_name] = loss.detach().clone()
+
+        losses["unscaled_loss"] = cum_loss.detach().clone()
 
         # Scale the loss by the square root of the minimum of the crop size and
         # the (average) sequence length. See subsection 1.9.
@@ -1623,4 +1627,7 @@ class AlphaFoldLoss(nn.Module):
         crop_len = batch["aatype"].shape[-1]
         cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
 
-        return cum_loss
+        if(not _return_breakdown):
+            return cum_loss
+
+        return cum_loss, losses
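Taken together, forward now keeps an unweighted, detached copy of every term alongside the weighted sum, and only returns the breakdown when asked. A self-contained toy version of that pattern (dummy terms and weights, not the real AlphaFoldLoss configuration):

    import logging
    import torch

    loss_fns = {
        "fape": lambda: torch.tensor(1.2, requires_grad=True),
        "plddt_loss": lambda: torch.tensor(float("nan")),  # simulate a bad term
    }
    weights = {"fape": 1.0, "plddt_loss": 0.01}

    cum_loss = 0.
    losses = {}
    for loss_name, loss_fn in loss_fns.items():
        loss = loss_fn()
        if torch.isnan(loss) or torch.isinf(loss):
            logging.warning(f"{loss_name} loss is NaN. Skipping...")
            loss = loss.new_tensor(0., requires_grad=True)
        cum_loss = cum_loss + weights[loss_name] * loss
        losses[loss_name] = loss.detach().clone()

    losses["unscaled_loss"] = cum_loss.detach().clone()
    print(cum_loss, losses)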
openfold/utils/superimposition.py — new file (0 → 100644), view file @ 50f72d45

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np
import torch


def _superimpose_np(reference, coords):
    """
        Superimposes coordinates onto a reference by minimizing RMSD using SVD.

        Args:
            reference:
                [N, 3] reference array
            coords:
                [N, 3] array
        Returns:
            A tuple of [N, 3] superimposed coords and the final RMSD.
    """
    sup = SVDSuperimposer()
    sup.set(reference, coords)
    sup.run()
    return sup.get_transformed(), sup.get_rms()


def _superimpose_single(reference, coords):
    reference_np = reference.detach().cpu().numpy()
    coords_np = coords.detach().cpu().numpy()
    superimposed, rmsd = _superimpose_np(reference_np, coords_np)
    return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


def superimpose(reference, coords):
    """
        Superimposes coordinates onto a reference by minimizing RMSD using SVD.

        Args:
            reference:
                [*, N, 3] reference tensor
            coords:
                [*, N, 3] tensor
        Returns:
            A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
    """
    batch_dims = reference.shape[:-2]
    flat_reference = reference.reshape((-1,) + reference.shape[-2:])
    flat_coords = coords.reshape((-1,) + reference.shape[-2:])
    superimposed_list = []
    rmsds = []
    for r, c in zip(flat_reference, flat_coords):
        superimposed, rmsd = _superimpose_single(r, c)
        superimposed_list.append(superimposed)
        rmsds.append(rmsd)

    superimposed_stacked = torch.stack(superimposed_list, dim=0)
    rmsds_stacked = torch.stack(rmsds, dim=0)

    superimposed_reshaped = superimposed_stacked.reshape(
        batch_dims + coords.shape[-2:]
    )
    rmsds_reshaped = rmsds_stacked.reshape(batch_dims)

    return superimposed_reshaped, rmsds_reshaped
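A quick usage sketch of the new helper on synthetic coordinates (requires Biopython for SVDSuperimposer; shapes follow the docstring above):

    import torch
    from openfold.utils.superimposition import superimpose

    reference = torch.randn(2, 10, 3)                 # [*, N, 3]
    coords = reference + 0.1 * torch.randn(2, 10, 3)  # perturbed copy
    aligned, rmsds = superimpose(reference, coords)
    print(aligned.shape, rmsds.shape)                 # (2, 10, 3) and (2,)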
openfold/utils/validation_metrics.py — new file (0 → 100644), view file @ 50f72d45

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def gdt(p1, p2, mask, cutoffs):
    n = torch.sum(mask, dim=-1)

    p1 = p1.float()
    p2 = p2.float()
    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
    scores = []
    for c in cutoffs:
        score = torch.sum((distances <= c) * mask, dim=-1) / n
        scores.append(score)

    return sum(scores) / len(scores)


def gdt_ts(p1, p2, mask):
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
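A small sanity check of the new metrics on synthetic coordinates (identical structures should score 1.0, and a large rigid displacement should score 0.0):

    import torch
    from openfold.utils.validation_metrics import gdt_ts, gdt_ha

    p = torch.randn(5, 3)
    mask = torch.ones(5)
    print(gdt_ts(p, p, mask))        # tensor(1.)
    print(gdt_ha(p, p + 10., mask))  # ~17 A per atom, beyond every cutoff: tensor(0.)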
train_openfold.py — view file @ 50f72d45

@@ -26,14 +26,20 @@ from openfold.data.data_modules import (
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
+from openfold.np import residue_constants
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.argparse import remove_arguments
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
 from openfold.utils.seed import seed_everything
+from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
+from openfold.utils.validation_metrics import (
+    gdt_ts,
+    gdt_ha,
+)
 from scripts.zero_to_fp32 import (
     get_fp32_state_dict_from_zero_checkpoint
 )
@@ -52,6 +58,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
+        self.last_lr_step = 0
 
     def forward(self, batch):
         return self.model(batch)
@@ -67,43 +74,132 @@ class OpenFoldWrapper(pl.LightningModule):
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
 
         # Compute loss
-        loss = self.loss(outputs, batch)
+        loss, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )
 
         # Log it
-        self.log("train/loss", loss, on_step=True, logger=True)
+        self.log(
+            "train/loss", loss,
+            on_step=True, logger=True,
+        )
+
+        self.log(
+            "train/loss_epoch", loss,
+            on_step=False, on_epoch=True, logger=True,
+        )
+
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"train/{loss_name}",
+                indiv_loss,
+                on_step=True, logger=True,
+            )
+
+        with torch.no_grad():
+            other_metrics = self.compute_validation_metrics(batch, outputs)
+
+        for k, v in other_metrics.items():
+            self.log(
+                f"train/{k}", v,
+                on_step=False, on_epoch=True, logger=True
+            )
 
         return loss
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)
 
+    # def training_step_end(self, outputs):
+    #     # Temporary measure to address DeepSpeed scheduler bug
+    #     if(self.trainer.global_step != self.last_lr_step):
+    #         self.lr_schedulers().step()
+    #         self.last_lr_step = self.trainer.global_step
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])
 
-        # Calculate validation loss
+        # Run the model
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
 
-        lddt_ca_score = lddt_ca(
-            outputs["final_atom_positions"],
-            batch["all_atom_positions"],
-            batch["all_atom_mask"],
-            eps=self.config.globals.eps,
-            per_residue=False,
-        )
-        self.log("val/lddt_ca", lddt_ca_score, logger=True)
-
         # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
-        loss = self.loss(outputs, batch)
-        self.log("val/loss", loss, logger=True)
+        loss, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )
+        self.log("val/loss", loss, on_step=False, on_epoch=True, logger=True)
+
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"val/{loss_name}",
+                indiv_loss,
+                on_step=False, on_epoch=True, logger=True,
+            )
+
+        other_metrics = self.compute_validation_metrics(
+            batch, outputs, superimposition_metrics=True,
+        )
+
+        for k, v in other_metrics.items():
+            self.log(f"val/{k}", v, on_step=False, on_epoch=True, logger=True)
     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
+    def compute_validation_metrics(self,
+        batch,
+        outputs,
+        superimposition_metrics=False
+    ):
+        metrics = {}
+
+        gt_coords = batch["all_atom_positions"]
+        pred_coords = outputs["final_atom_positions"]
+        all_atom_mask = batch["all_atom_mask"]
+
+        # This is super janky for superimposition. Fix later
+        gt_coords_masked = gt_coords * all_atom_mask[..., None]
+        pred_coords_masked = pred_coords * all_atom_mask[..., None]
+        ca_pos = residue_constants.atom_order["CA"]
+        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
+        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
+        all_atom_mask_ca = all_atom_mask[..., ca_pos]
+
+        lddt_ca_score = lddt_ca(
+            pred_coords,
+            gt_coords,
+            all_atom_mask,
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+
+        metrics["lddt_ca"] = lddt_ca_score
+
+        drmsd_ca_score = compute_drmsd(
+            pred_coords_masked_ca,
+            gt_coords_masked_ca,
+            mask=all_atom_mask_ca,
+        )
+
+        metrics["drmsd_ca"] = drmsd_ca_score
+
+        if(superimposition_metrics):
+            superimposed_pred, _ = superimpose(
+                gt_coords_masked_ca, pred_coords_masked_ca
+            )
+            gdt_ts_score = gdt_ts(
+                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
+            )
+            gdt_ha_score = gdt_ha(
+                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
+            )
+
+            metrics["gdt_ts"] = gdt_ts_score
+            metrics["gdt_ta"] = gdt_ha_score
+
+        return metrics
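The masking and indexing at the top of compute_validation_metrics reduce the full 37-atom representation to Cα-only tensors before dRMSD and the superimposition-based scores. A shape-only sketch with synthetic tensors (in AlphaFold's atom ordering, atom_order["CA"] is index 1):

    import torch

    ca_pos = 1                                       # residue_constants.atom_order["CA"]
    all_atom_positions = torch.randn(4, 128, 37, 3)  # [*, N_res, 37, 3]
    all_atom_mask = torch.ones(4, 128, 37)           # [*, N_res, 37]

    masked = all_atom_positions * all_atom_mask[..., None]  # zero out unresolved atoms
    ca_coords = masked[..., ca_pos, :]                      # [*, N_res, 3]
    ca_mask = all_atom_mask[..., ca_pos]                    # [*, N_res]
    print(ca_coords.shape, ca_mask.shape)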
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
@@ -180,6 +276,10 @@ def main(args):
         )
         callbacks.append(perf)
 
+    if(args.log_lr):
+        lr_monitor = LearningRateMonitor(logging_interval="step")
+        callbacks.append(lr_monitor)
+
     loggers = []
     if(args.wandb):
         wdb_logger = WandbLogger(
@@ -202,7 +302,7 @@ def main(args):
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         default_root_dir=args.output_dir,
@@ -366,6 +466,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--log_lr", action="store_true", default=False,
+    )
     parser = pl.Trainer.add_argparse_args(parser)
 
     # Disable the initial validation pass