Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
0ca66146
Unverified
Commit
0ca66146
authored
Sep 29, 2023
by
Christina Floristean
Committed by
GitHub
Sep 29, 2023
Browse files
Merge pull request #353 from dingquanyu/permutation
Update multi-chain permutation and training codes
parents
8820875b
a9d65037
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
258 additions
and
178 deletions
+258
-178
openfold/data/data_modules.py
openfold/data/data_modules.py
+11
-10
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+22
-8
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+35
-33
openfold/utils/loss.py
openfold/utils/loss.py
+173
-113
tests/test_permutation.py
tests/test_permutation.py
+5
-4
train_openfold.py
train_openfold.py
+12
-10
No files found.
openfold/data/data_modules.py
View file @
0ca66146
...
@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import (
...
@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
tensor_tree_map
,
)
)
import
random
import
random
logging
.
basicConfig
(
level
=
logging
.
INFO
)
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
temp_fasta_file
(
sequence_str
):
def
temp_fasta_file
(
sequence_str
):
...
@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
)
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
with
open
(
path
,
'r'
)
as
f
:
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
mmcif_string
=
f
.
read
()
...
@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif
=
mmcif_object
,
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
alignment_index
=
alignment_index
alignment_index
=
alignment_index
)
)
return
data
return
data
...
@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
idx_to_mmcif_id
(
self
,
idx
):
def
idx_to_mmcif_id
(
self
,
idx
):
return
self
.
_mmcifs
[
idx
]
return
self
.
_mmcifs
[
idx
]
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
alignment_index
=
None
alignment_index
=
None
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
ext
=
None
ext
=
None
...
@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if
(
self
.
_output_raw
):
if
(
self
.
_output_raw
):
return
data
return
data
# process all_chain_features
# process all_chain_features
data
=
self
.
feature_pipeline
.
process_features
(
data
,
data
,
ground_truth
=
self
.
feature_pipeline
.
process_features
(
data
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
is_multimer
=
True
)
is_multimer
=
True
)
# if it's inference mode, only need all_chain_features
# if it's inference mode, only need all_chain_features
data
[
"batch_idx"
]
=
torch
.
tensor
(
data
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
data
[
"aatype"
].
shape
[
-
1
])],
[
idx
for
_
in
range
(
data
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
data
[
"aatype"
].
device
)
device
=
data
[
"aatype"
].
device
)
return
data
return
data
,
ground_truth
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_chain_ids
)
return
len
(
self
.
_chain_ids
)
...
@@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...
@@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
i
)
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
i
)
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
,
max_resolution
=
9
):
minimum_number_of_residues
=
5
):
selected_idx
.
append
(
i
)
selected_idx
.
append
(
i
)
logging
.
info
(
f
"Originally
{
len
(
mmcif_data_cache
)
}
mmcifs. After filtering:
{
len
(
selected_idx
)
}
"
)
else
:
else
:
selected_idx
=
list
(
range
(
len
(
dataset
.
_mmcif_id_to_idx_dict
)))
selected_idx
=
list
(
range
(
len
(
dataset
.
_mmcif_id_to_idx_dict
)))
return
selected_idx
return
selected_idx
...
...
openfold/data/feature_pipeline.py
View file @
0ca66146
...
@@ -81,7 +81,7 @@ def np_example_to_features(
...
@@ -81,7 +81,7 @@ def np_example_to_features(
seq_length
=
np_example
[
"seq_length"
]
seq_length
=
np_example
[
"seq_length"
]
num_res
=
int
(
seq_length
[
0
])
if
seq_length
.
ndim
!=
0
else
int
(
seq_length
)
num_res
=
int
(
seq_length
[
0
])
if
seq_length
.
ndim
!=
0
else
int
(
seq_length
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
if
"deletion_matrix_int"
in
np_example
:
if
"deletion_matrix_int"
in
np_example
:
np_example
[
"deletion_matrix"
]
=
np_example
.
pop
(
np_example
[
"deletion_matrix"
]
=
np_example
.
pop
(
"deletion_matrix_int"
"deletion_matrix_int"
...
@@ -90,15 +90,29 @@ def np_example_to_features(
...
@@ -90,15 +90,29 @@ def np_example_to_features(
tensor_dict
=
np_to_tensor_dict
(
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
np_example
=
np_example
,
features
=
feature_names
)
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
if
(
not
is_multimer
):
if
is_multimer
:
features
=
input_pipeline
.
process_tensors_from_config
(
if
mode
==
'train'
:
tensor_dict
,
features
,
gt_features
=
input_pipeline_multimer
.
process_tensors_from_config
(
cfg
.
common
,
tensor_dict
,
cfg
[
mode
],
cfg
.
common
,
)
cfg
[
mode
],
is_training
=
True
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()},
gt_features
else
:
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
is_training
=
False
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
else
:
else
:
features
=
input_pipeline
_multimer
.
process_tensors_from_config
(
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
tensor_dict
,
cfg
.
common
,
cfg
.
common
,
cfg
[
mode
],
cfg
[
mode
],
...
...
openfold/data/input_pipeline_multimer.py
View file @
0ca66146
...
@@ -21,19 +21,8 @@ from openfold.data import (
...
@@ -21,19 +21,8 @@ from openfold.data import (
data_transforms_multimer
,
data_transforms_multimer
,
)
)
def
grountruth_transforms_fns
():
def
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
):
transforms
=
[
data_transforms
.
make_atom14_masks
,
"""Input pipeline data transformers that are not ensembled."""
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
if
mode_cfg
.
supervised
:
transforms
.
extend
(
[
data_transforms
.
make_atom14_positions
,
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
""
),
data_transforms
.
atom37_to_torsion_angles
(
""
),
...
@@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
...
@@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
,
data_transforms
.
get_chi_angles
,
]
]
)
return
transforms
def
nonensembled_transform_fns
():
"""Input pipeline data transformers that are not ensembled."""
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
]
return
transforms
return
transforms
...
@@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return
transforms
return
transforms
def
prepare_ground_truth_features
(
tensors
):
"""Prepare ground truth features that are only needed for loss calculation during training"""
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
GROUNDTRUTH_FEATURES
=
[
'all_atom_mask'
,
'all_atom_positions'
,
'asym_id'
,
'sym_id'
,
'entity_id'
]
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
GROUNDTRUTH_FEATURES
}
gt_tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
gt_tensors
=
compose
(
grountruth_transforms_fns
())(
gt_tensors
)
return
gt_tensors
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
is_training
=
False
):
"""Based on the config, apply filters and transformations to the data."""
"""Based on the config, apply filters and transformations to the data."""
if
is_training
:
gt_tensors
=
prepare_ground_truth_features
(
tensors
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
nonensembled
=
nonensembled_transform_fns
()
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
def
wrap_ensemble_fn
(
data
,
i
):
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
"""Function to be mapped over the ensemble dimension."""
...
@@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
...
@@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d
[
"ensemble_index"
]
=
i
d
[
"ensemble_index"
]
=
i
return
fn
(
d
)
return
fn
(
d
)
no_templates
=
True
if
(
"template_aatype"
in
tensors
):
no_templates
=
tensors
[
"template_aatype"
].
shape
[
0
]
==
0
nonensembled
=
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
,
)
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
tensors
=
map_fn
(
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
)
)
return
tensors
if
is_training
:
return
tensors
,
gt_tensors
else
:
return
tensors
@
data_transforms
.
curry1
@
data_transforms
.
curry1
def
compose
(
x
,
fs
):
def
compose
(
x
,
fs
):
...
...
openfold/utils/loss.py
View file @
0ca66146
...
@@ -1700,9 +1700,6 @@ def compute_rmsd(
...
@@ -1700,9 +1700,6 @@ def compute_rmsd(
atom_mask
:
torch
.
Tensor
=
None
,
atom_mask
:
torch
.
Tensor
=
None
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# shape check
true_atom_pos
=
true_atom_pos
pred_atom_pos
=
pred_atom_pos
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
del
true_atom_pos
del
true_atom_pos
del
pred_atom_pos
del
pred_atom_pos
...
@@ -1784,20 +1781,23 @@ def get_optimal_transform(
...
@@ -1784,20 +1781,23 @@ def get_optimal_transform(
return
r
,
x
return
r
,
x
def
get_least_asym_entity_or_longest_length
(
batch
):
def
get_least_asym_entity_or_longest_length
(
batch
,
input_asym_id
):
"""
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES
=
[
'entity_id'
,
'asym_id'
]
seq_length
=
batch
[
'seq_length'
].
item
()
# remove padding part before selecting candidate
Args:
remove_padding
=
lambda
t
:
torch
.
index_select
(
t
,
dim
=
1
,
index
=
torch
.
arange
(
seq_length
,
device
=
t
.
device
))
batch: in this funtion batch is the full ground truth features
batch
=
{
k
:
tensor_tree_map
(
remove_padding
,
batch
[
k
])
for
k
in
REQUIRED_FEATURES
}
input_asym_id: A list of aym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_asym_count
=
{}
entity_length
=
{}
entity_length
=
{}
...
@@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch):
if
len
(
least_asym_entities
)
>
1
:
if
len
(
least_asym_entities
)
>
1
:
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
least_asym_entities
=
least_asym_entities
[
0
]
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly picke one
if
len
(
best_pred_asym
)
>
1
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
anchor_gt_asym_id
=
random
.
choice
(
entity_2_asym_list
[
least_asym_entities
])
anchor_pred_asym_ids
=
[
id
for
id
in
entity_2_asym_list
[
least_asym_entities
]
if
id
in
input_asym_id
]
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -1847,6 +1843,7 @@ def greedy_align(
...
@@ -1847,6 +1843,7 @@ def greedy_align(
"""
"""
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
i
=
int
(
cur_asym_id
-
1
)
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
...
@@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
...
@@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
"""
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
align: list of tuples, each entry specify the corresponding label of the asym.
...
@@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
...
@@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
cur_out
=
{}
cur_out
=
{}
for
i
,
j
in
align
:
for
i
,
j
in
align
:
label
=
labels
[
j
][
k
]
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
continue
else
:
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
dimension_to_merge
=
1
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
if
len
(
cur_out
)
>
0
:
...
@@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
(
config
)
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
(
config
)
self
.
config
=
config
self
.
config
=
config
@
staticmethod
def
determine_split_dim
(
batch
)
->
dict
:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim
=
batch
[
'aatype'
].
shape
[
-
1
]
dim_dict
=
{
k
:
list
(
v
.
shape
).
index
(
padded_dim
)
for
k
,
v
in
batch
.
items
()
if
padded_dim
in
v
.
shape
}
return
dim_dict
@
staticmethod
@
staticmethod
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
,
dim_dict
):
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
,
split_dim
=
1
):
"""
"""
Splits ground truth features according to chains
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
required to finish multi-chain permutation
"""
"""
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
...
@@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
unique_asym_ids
.
append
(
padding_asym_id
)
unique_asym_ids
.
append
(
padding_asym_id
)
asym_id_counts
.
append
(
padding_asym_counts
)
asym_id_counts
.
append
(
padding_asym_counts
)
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
]
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
split_dim
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
return
labels
@
staticmethod
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
):
def
get_per_asym_residue_index
(
features
):
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
features
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
features
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
features
[
"residue_index"
],
asym_mask
)
return
per_asym_residue_index
@
staticmethod
def
get_entity_2_asym_list
(
batch
):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list
=
{}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
return
entity_2_asym_list
@
staticmethod
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
anchor_gt_residue
,
asym_mask
,
pred_ca_mask
):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
):
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
anchor_gt_residue
,
asym_mask
,
pred_ca_mask
)
input_mask
=
torch
.
squeeze
(
input_mask
,
0
)
pred_ca_pos
=
torch
.
squeeze
(
pred_ca_pos
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
torch
.
squeeze
(
anchor_true_pos
,
0
),
mask
=
input_mask
)
return
r
,
x
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
permutate_chains
=
False
):
"""
"""
A class method that first permutate chains in ground truth first
A class method that first permutate chains in ground truth first
before calculating the loss.
before calculating the loss.
...
@@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
"""
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
feature
,
ground_truth
=
batch
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
del
batch
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
best_rmsd
=
float
(
'inf'
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
best_align
=
None
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym_ids
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
del
ground_truth
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
# Then calculate optimal transform by aligning anchors
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
ca_idx
=
rc
.
atom_order
[
"CA"
]
entity_2_asym_list
=
{}
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
for
cur_ent_id
in
unique_entity_ids
:
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
true_ca_poses
=
[
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
]
# list([nres, 3])
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
true_ca_masks
=
[
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
]
# list([nres,])
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
for
candidate_pred_anchor
in
anchor_pred_asym_ids
:
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
asym_mask
=
(
feature
[
"asym_id"
]
==
candidate_pred_anchor
).
bool
()
anchor_gt_residue
=
per_asym_residue_index
[
int
(
candidate_pred_anchor
)]
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
r
,
x
=
get_optimal_transform
(
anchor_gt_idx
,
anchor_gt_residue
,
anchor_pred_pos
,
anchor_true_pos
[
0
],
true_ca_masks
,
pred_ca_mask
,
mask
=
input_mask
[
0
]
asym_mask
,
)
pred_ca_pos
del
input_mask
# just to save memory
)
del
anchor_pred_mask
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
anchor_true_mask
align
=
greedy_align
(
gc
.
collect
()
feature
,
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
per_asym_residue_index
,
del
true_ca_poses
entity_2_asym_list
,
gc
.
collect
()
pred_ca_pos
,
align
=
greedy_align
(
pred_ca_mask
,
batch
,
aligned_true_ca_poses
,
per_asym_residue_index
,
true_ca_masks
,
unique_asym_ids
,
)
entity_2_asym_list
,
merged_labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
pred_ca_pos
,
original_nres
=
feature
[
'aatype'
].
shape
[
-
1
])
pred_ca_mask
,
rmsd
=
compute_rmsd
(
true_atom_pos
=
merged_labels
[
'all_atom_positions'
][...,
ca_idx
,
:].
to
(
r
.
dtype
)
@
r
+
x
,
aligned_true_ca_poses
,
pred_atom_pos
=
pred_ca_pos
,
true_ca_masks
,
atom_mask
=
(
pred_ca_mask
*
merged_labels
[
'all_atom_mask'
][...,
ca_idx
].
long
()).
bool
())
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
del
aligned_true_ca_poses
,
true_ca_masks
best_align
=
align
del
r
,
x
del
r
,
x
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
del
anchor_pred_pos
,
anchor_true_pos
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
else
:
else
:
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
return
align
,
per_asym_residue_index
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
best_align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
best_align
,
per_asym_residue_index
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
"""
Overwrite AlphaFoldLoss forward function so that
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
it first compute multi-chain permutation
...
@@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
batch: a pair of input features and its corresponding ground truth structure
"""
"""
# first check if it is a monomer
# first check if it is a monomer
features
,
ground_truth
=
batch
del
batch
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
if
not
is_monomer
:
if
not
is_monomer
:
permutate_chains
=
True
permutate_chains
=
True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
permutate_chains
=
permutate_chains
)
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
REQUIRED_FEATURES
=
[
i
for
i
in
ground_truth
.
keys
()])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
# reorder ground truth labels according to permutation results
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
features
.
update
(
labels
)
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
if
(
not
_return_breakdown
):
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
...
...
tests/test_permutation.py
View file @
0ca66146
...
@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
...
@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
batch
[
'all_atom_mask'
]
=
true_atom_mask
batch
[
'all_atom_mask'
]
=
true_atom_mask
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
_
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
dim_dict
,
permutate_chains
=
True
)
permutate_chains
=
True
)
print
(
f
"##### aligns is
{
aligns
}
"
)
possible_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
3
),(
3
,
4
),(
4
,
2
)],[(
0
,
0
),(
1
,
1
),(
2
,
3
),(
3
,
4
),(
4
,
2
)]]
possible_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
3
),(
3
,
4
),(
4
,
2
)],[(
0
,
0
),(
1
,
1
),(
2
,
3
),(
3
,
4
),(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
4
),(
3
,
2
),(
4
,
3
)],[(
0
,
0
),(
1
,
1
),(
2
,
2
),(
3
,
3
),(
4
,
4
)]]
wrong_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
4
),(
3
,
2
),(
4
,
3
)],[(
0
,
0
),(
1
,
1
),(
2
,
2
),(
3
,
3
),(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertIn
(
aligns
,
possible_outcome
)
...
@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
...
@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
batch
,
dim_dict
,
dim_dict
,
permutate_chains
=
True
)
permutate_chains
=
True
)
print
(
f
"##### aligns is
{
aligns
}
"
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
labels
=
merge_labels
(
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
...
...
train_openfold.py
View file @
0ca66146
...
@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
...
@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
def training_step(self, batch, batch_idx):
    """Run one multimer training step.

    Args:
        batch: A 2-tuple ``(features, gt_features)`` — the model input
            features and the separate ground-truth features. The ground
            truth travels alongside the inputs so the multimer loss can
            permute chains against the prediction before scoring.
        batch_idx: Index of this batch (required by the Lightning API;
            unused here).

    Returns:
        The scalar training loss.
    """
    # Reconstructed post-merge implementation: the mangled side-by-side
    # diff interleaved the old (``batch``-only) and new
    # (``features``/``gt_features``) versions token-by-token; this is the
    # new side, formatted conventionally.
    features, gt_features = batch

    # Keep the EMA shadow weights on the same device as the inputs.
    if self.ema.device != features["aatype"].device:
        self.ema.to(features["aatype"].device)

    # Run the model on the input features only.
    outputs = self(features)

    # Remove the recycling dimension (last axis of every feature tensor).
    features = tensor_tree_map(lambda t: t[..., -1], features)

    # Compute loss. The multimer loss receives ``(features, gt_features)``
    # so it can perform multi-chain permutation alignment internally
    # before comparing prediction to ground truth.
    loss, loss_breakdown = self.loss(
        outputs, (features, gt_features), _return_breakdown=True
    )

    # Log the per-term loss breakdown.
    self._log(loss_breakdown, features, outputs)

    return loss
def
validation_step
(
self
,
batch
,
batch_idx
):
def
validation_step
(
self
,
batch
,
batch_idx
):
features
,
gt_features
=
batch
# At the start of validation, load the EMA weights
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
if
(
self
.
cached_weights
is
None
):
# model.state_dict() contains references to model weights rather
# model.state_dict() contains references to model weights rather
...
@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
...
@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Run the model
# Run the model
outputs
=
self
(
batch
)
outputs
=
self
(
features
)
# Compute loss and other metrics
# Compute loss and other metrics
batch
[
"use_clamped_fape"
]
=
0.
features
[
"use_clamped_fape"
]
=
0.
_
,
loss_breakdown
=
self
.
loss
(
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
outputs
,
(
features
,
gt_features
)
,
_return_breakdown
=
True
)
)
self
.
_log
(
loss_breakdown
,
batch
,
outputs
,
train
=
False
)
self
.
_log
(
loss_breakdown
,
features
,
outputs
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
# Restore the model weights to normal
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment