Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a18f98cf
Commit
a18f98cf
authored
Aug 10, 2023
by
Christina Floristean
Browse files
Seed fixes for multimer data pipeline
parent
44b0bf76
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
57 additions
and
17 deletions
+57
-17
openfold/config.py
openfold/config.py
+12
-4
openfold/data/data_modules.py
openfold/data/data_modules.py
+4
-6
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+2
-1
openfold/data/data_transforms_multimer.py
openfold/data/data_transforms_multimer.py
+6
-3
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+33
-3
No files found.
openfold/config.py
View file @
a18f98cf
...
@@ -155,8 +155,6 @@ def model_config(
...
@@ -155,8 +155,6 @@ def model_config(
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
elif
"multimer"
in
name
:
c
.
update
(
multimer_config_update
.
copy_and_resolve_references
())
c
.
update
(
multimer_config_update
.
copy_and_resolve_references
())
del
c
.
model
.
template
.
template_pointwise_attention
del
c
.
loss
.
fape
.
backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
...
@@ -354,6 +352,8 @@ config = mlc.ConfigDict(
...
@@ -354,6 +352,8 @@ config = mlc.ConfigDict(
"max_templates"
:
4
,
"max_templates"
:
4
,
"crop"
:
False
,
"crop"
:
False
,
"crop_size"
:
None
,
"crop_size"
:
None
,
"spatial_crop_prob"
:
None
,
"interface_threshold"
:
None
,
"supervised"
:
False
,
"supervised"
:
False
,
"uniform_recycling"
:
False
,
"uniform_recycling"
:
False
,
},
},
...
@@ -367,6 +367,8 @@ config = mlc.ConfigDict(
...
@@ -367,6 +367,8 @@ config = mlc.ConfigDict(
"max_templates"
:
4
,
"max_templates"
:
4
,
"crop"
:
False
,
"crop"
:
False
,
"crop_size"
:
None
,
"crop_size"
:
None
,
"spatial_crop_prob"
:
None
,
"interface_threshold"
:
None
,
"supervised"
:
True
,
"supervised"
:
True
,
"uniform_recycling"
:
False
,
"uniform_recycling"
:
False
,
},
},
...
@@ -381,6 +383,8 @@ config = mlc.ConfigDict(
...
@@ -381,6 +383,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered"
:
20
,
"shuffle_top_k_prefiltered"
:
20
,
"crop"
:
True
,
"crop"
:
True
,
"crop_size"
:
256
,
"crop_size"
:
256
,
"spatial_crop_prob"
:
0.
,
"interface_threshold"
:
None
,
"supervised"
:
True
,
"supervised"
:
True
,
"clamp_prob"
:
0.9
,
"clamp_prob"
:
0.9
,
"max_distillation_msa_clusters"
:
1000
,
"max_distillation_msa_clusters"
:
1000
,
...
@@ -709,7 +713,9 @@ multimer_config_update = mlc.ConfigDict({
...
@@ -709,7 +713,9 @@ multimer_config_update = mlc.ConfigDict({
"train"
:
{
"train"
:
{
"max_msa_clusters"
:
508
,
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
,
"max_extra_msa"
:
2048
,
"crop_size"
:
640
"crop_size"
:
640
,
"spatial_crop_prob"
:
0.5
,
"interface_threshold"
:
10.
},
},
},
},
"model"
:
{
"model"
:
{
...
@@ -735,6 +741,7 @@ multimer_config_update = mlc.ConfigDict({
...
@@ -735,6 +741,7 @@ multimer_config_update = mlc.ConfigDict({
"tri_mul_first"
:
True
,
"tri_mul_first"
:
True
,
"fuse_projection_weights"
:
True
"fuse_projection_weights"
:
True
},
},
"template_pointwise_attention"
:
None
,
# Not used in Multimer
"c_t"
:
c_t
,
"c_t"
:
c_t
,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"use_unit_vector"
:
True
"use_unit_vector"
:
True
...
@@ -778,7 +785,8 @@ multimer_config_update = mlc.ConfigDict({
...
@@ -778,7 +785,8 @@ multimer_config_update = mlc.ConfigDict({
"clamp_distance"
:
30.0
,
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
"weight"
:
0.5
}
},
"backbone"
:
None
# Not used in Multimer
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"num_classes"
:
22
"num_classes"
:
22
...
...
openfold/data/data_modules.py
View file @
a18f98cf
...
@@ -4,7 +4,7 @@ import json
...
@@ -4,7 +4,7 @@ import json
import
logging
import
logging
import
os
import
os
import
pickle
import
pickle
from
typing
import
Optional
,
Sequence
,
List
,
Any
from
typing
import
Optional
,
Sequence
,
Any
import
ml_collections
as
mlc
import
ml_collections
as
mlc
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
...
@@ -880,10 +880,7 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
...
@@ -880,10 +880,7 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
config
=
config
self
.
config
=
config
self
.
stage
=
stage
self
.
stage
=
stage
if
(
generator
is
None
):
generator
=
torch
.
Generator
()
self
.
generator
=
generator
self
.
generator
=
generator
print
(
'initialised a multimer dataloader'
)
print
(
'initialised a multimer dataloader'
)
...
@@ -1220,8 +1217,9 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
...
@@ -1220,8 +1217,9 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
)
)
def
_gen_dataloader
(
self
,
stage
):
def
_gen_dataloader
(
self
,
stage
):
generator
=
torch
.
Generator
()
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
dataset
=
None
dataset
=
None
...
...
openfold/data/data_transforms.py
View file @
a18f98cf
...
@@ -477,8 +477,9 @@ def make_masked_msa(protein, config, replace_fraction, seed):
...
@@ -477,8 +477,9 @@ def make_masked_msa(protein, config, replace_fraction, seed):
sh
=
protein
[
"msa"
].
shape
sh
=
protein
[
"msa"
].
shape
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
sample
=
torch
.
rand
(
sh
,
device
=
device
,
generator
=
g
)
sample
=
torch
.
rand
(
sh
,
device
=
device
,
generator
=
g
)
...
...
openfold/data/data_transforms_multimer.py
View file @
a18f98cf
...
@@ -100,8 +100,9 @@ def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
...
@@ -100,8 +100,9 @@ def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
logits
=
torch
.
log
(
categorical_probs
+
eps
)
logits
=
torch
.
log
(
categorical_probs
+
eps
)
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
bert_msa
=
gumbel_max_sample
(
logits
,
generator
=
g
)
bert_msa
=
gumbel_max_sample
(
logits
,
generator
=
g
)
...
@@ -262,8 +263,9 @@ def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
...
@@ -262,8 +263,9 @@ def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
Returns:
Returns:
Protein with sampled msa.
Protein with sampled msa.
"""
"""
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
# Sample uniformly among sequences with at least one non-masked position.
# Sample uniformly among sequences with at least one non-masked position.
...
@@ -417,8 +419,9 @@ def random_crop_to_size(
...
@@ -417,8 +419,9 @@ def random_crop_to_size(
):
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
use_spatial_crop
=
torch
.
rand
((
1
,),
use_spatial_crop
=
torch
.
rand
((
1
,),
...
...
openfold/data/input_pipeline_multimer.py
View file @
a18f98cf
...
@@ -13,8 +13,7 @@
...
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
import
random
import
torch
import
torch
from
openfold.data
import
(
from
openfold.data
import
(
...
@@ -75,13 +74,44 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -75,13 +74,44 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
transforms
.
append
(
data_transforms_multimer
.
nearest_neighbor_clusters
())
transforms
.
append
(
data_transforms_multimer
.
nearest_neighbor_clusters
())
transforms
.
append
(
data_transforms_multimer
.
create_msa_feat
)
transforms
.
append
(
data_transforms_multimer
.
create_msa_feat
)
crop_feats
=
dict
(
common_cfg
.
feat
)
if
mode_cfg
.
fixed_size
:
transforms
.
append
(
data_transforms
.
select_feat
(
list
(
crop_feats
)))
if
mode_cfg
.
crop
:
transforms
.
append
(
data_transforms_multimer
.
random_crop_to_size
(
crop_size
=
mode_cfg
.
crop_size
,
max_templates
=
mode_cfg
.
max_templates
,
shape_schema
=
crop_feats
,
spatial_crop_prob
=
mode_cfg
.
spatial_crop_prob
,
interface_threshold
=
mode_cfg
.
interface_threshold
,
subsample_templates
=
mode_cfg
.
subsample_templates
,
seed
=
ensemble_seed
+
1
,
)
)
transforms
.
append
(
data_transforms
.
make_fixed_size
(
shape_schema
=
crop_feats
,
msa_cluster_size
=
pad_msa_clusters
,
extra_msa_size
=
mode_cfg
.
max_extra_msa
,
num_res
=
mode_cfg
.
crop_size
,
num_templates
=
mode_cfg
.
max_templates
,
)
)
else
:
transforms
.
append
(
data_transforms
.
crop_templates
(
mode_cfg
.
max_templates
)
)
return
transforms
return
transforms
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed
=
torch
.
Generator
().
seed
(
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
def
wrap_ensemble_fn
(
data
,
i
):
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
"""Function to be mapped over the ensemble dimension."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment