Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
33d8de81
"examples/python/vscode:/vscode.git/clone" did not exist on "4017bd18d0e84b8463bfe381279d4c5a6fd0c6e0"
Commit
33d8de81
authored
Jul 09, 2023
by
Geoffrey Yu
Browse files
added multimer training filter criteria described in the multimer paper
parent
d35816e3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
109 additions
and
5 deletions
+109
-5
openfold/data/data_modules.py
openfold/data/data_modules.py
+109
-5
No files found.
openfold/data/data_modules.py
View file @
33d8de81
...
@@ -11,7 +11,7 @@ import numpy as np
...
@@ -11,7 +11,7 @@ import numpy as np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch
import
torch
from
torch.utils.data
import
RandomSampler
from
torch.utils.data
import
RandomSampler
from
openfold.np.residue_constants
import
restypes
from
openfold.data
import
(
from
openfold.data
import
(
data_pipeline
,
data_pipeline
,
feature_pipeline
,
feature_pipeline
,
...
@@ -579,6 +579,40 @@ def deterministic_train_filter(
...
@@ -579,6 +579,40 @@ def deterministic_train_filter(
return
True
return
True
def
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
:
9.
,
max_single_aa_prop
:
float
=
0.8
,
minimum_number_of_residues
:
int
=
200
,
)
->
bool
:
"""
Implement multimer training filtering criteria described in
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
# First check resolution
resolution
=
mmcif_data_cache_entry
.
get
(
"resolution"
,
None
)
if
(
resolution
is
not
None
and
resolution
>
max_resolution
):
return
False
# Then check if any single amino acid accounts for more than 80% of the complex sequences
seqs
=
mmcif_data_cache_entry
[
"seqs"
]
counts
=
{}
for
aa
in
restypes
:
counts
[
aa
]
=
0
total_len
=
sum
([
len
(
i
)
for
i
in
seqs
])
if
total_len
<
minimum_number_of_residues
:
# check if the complex has less than 200 residues
return
False
for
seq
in
seqs
:
for
aa
in
seq
:
counts
[
aa
]
+=
1
largest_aa_count
=
max
(
counts
.
values
())
largest_single_aa_prop
=
largest_aa_count
/
total_len
if
(
largest_single_aa_prop
>
max_single_aa_prop
):
return
False
return
True
def
get_stochastic_train_filter_prob
(
def
get_stochastic_train_filter_prob
(
chain_data_cache_entry
:
Any
,
chain_data_cache_entry
:
Any
,
...
@@ -694,12 +728,83 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -694,12 +728,83 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
print
(
f
"datapoints is
{
self
.
datapoints
}
"
)
print
(
f
"datapoints is
{
self
.
datapoints
}
"
)
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
"""
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
def
__init__
(
self
,
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
float
],
epoch_len
:
int
,
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
self
.
datasets
=
datasets
self
.
probabilities
=
probabilities
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
_samples
=
[
self
.
looped_samples
(
i
)
for
i
in
range
(
len
(
self
.
datasets
))]
def
looped_shuffled_dataset_idx
(
self
,
dataset_len
):
while
True
:
# Uniformly shuffle each dataset's indices
weights
=
[
1.
for
_
in
range
(
dataset_len
)]
shuf
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
dataset_len
,
replacement
=
False
,
generator
=
self
.
generator
,
)
for
idx
in
shuf
:
yield
idx
def
looped_samples
(
self
,
dataset_idx
):
max_cache_len
=
int
(
self
.
epoch_len
*
self
.
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
self
.
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain_data_cache
=
dataset
.
chain_data_cache
mmcif_data_cache
=
dataset
.
mmcif_data_cache
while
True
:
weights
=
[]
idx
=
[]
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
## TO DO: add filtering cretieria for multimer
mmcif_id
=
dataset
.
idx_to_mmcif_id
(
candidate_idx
)
chains
=
chain_data_cache
[
mmcif_id
][
'chain_ids'
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
(
not
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
)):
continue
p
=
get_stochastic_train_filter_prob
(
chain_data_cache_entry
,
)
weights
.
append
([
1.
-
p
,
p
])
idx
.
append
(
candidate_idx
)
samples
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
1
,
generator
=
self
.
generator
,
)
samples
=
samples
.
squeeze
()
cache
=
[
i
for
i
,
s
in
zip
(
idx
,
samples
)
if
s
]
for
datapoint_idx
in
cache
:
yield
datapoint_idx
class
OpenFoldBatchCollator
:
class
OpenFoldBatchCollator
:
def
__call__
(
self
,
prots
):
def
__call__
(
self
,
prots
):
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
return
dict_multimap
(
stack_fn
,
prots
)
return
dict_multimap
(
stack_fn
,
prots
)
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
...
@@ -796,7 +901,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -796,7 +901,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_data_dir
:
Optional
[
str
]
=
None
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_chain_data_cache_path
:
Optional
[
str
]
=
None
,
train_chain_data_cache_path
:
Optional
[
str
]
=
None
,
train_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_chain_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_chain_data_cache_path
:
Optional
[
str
]
=
None
,
...
@@ -824,7 +928,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -824,7 +928,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
train_data_dir
=
train_data_dir
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_chain_data_cache_path
=
train_chain_data_cache_path
self
.
train_chain_data_cache_path
=
train_chain_data_cache_path
self
.
train_mmcif_data_cache_path
=
train_mmcif_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_chain_data_cache_path
=
(
self
.
distillation_chain_data_cache_path
=
(
...
@@ -1045,6 +1148,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
...
@@ -1045,6 +1148,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_dataset
=
dataset_gen
(
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
data_dir
=
self
.
train_data_dir
,
mmcif_data_cache_path
=
self
.
train_mmcif_data_cache_path
,
mmcif_data_cache_path
=
self
.
train_mmcif_data_cache_path
,
chain_data_cache_path
=
self
.
train_chain_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
@@ -1084,7 +1188,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
...
@@ -1084,7 +1188,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
generator
=
torch
.
Generator
()
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
+
1
)
generator
=
generator
.
manual_seed
(
self
.
batch_seed
+
1
)
self
.
train_dataset
=
OpenFoldDataset
(
self
.
train_dataset
=
OpenFold
Multimer
Dataset
(
datasets
=
datasets
,
datasets
=
datasets
,
probabilities
=
probabilities
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
epoch_len
=
self
.
train_epoch_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment