OpenFold · Commit cfd0fc6e (unverified)

Authored Feb 03, 2022 by Gustaf Ahdritz; committed by GitHub, Feb 03, 2022

    Merge pull request #76 from aqlaboratory/chunking_experiment_rebased

Parents: c9e0f894, 2726892a
Showing 20 changed files with 1199 additions and 610 deletions (+1199 / -610).
openfold/config.py                                    +5    -3
openfold/data/data_modules.py                         +208  -65
openfold/data/data_pipeline.py                        +82   -30
openfold/model/embedders.py                           +3    -3
openfold/model/evoformer.py                           +229  -88
openfold/model/heads.py                               +2    -2
openfold/model/model.py                               +46   -38
openfold/model/msa.py                                 +104  -36
openfold/model/pair_transition.py                     +2    -2
openfold/model/primitives.py                          +209  -169
openfold/model/structure_module.py                    +13   -11
openfold/model/template.py                            +4    -5
openfold/model/triangular_attention.py                +6    -7
openfold/model/triangular_multiplicative_update.py    +3    -3
openfold/utils/checkpointing.py                       +13   -6
openfold/utils/import_weights.py                      +14   -8
openfold/utils/loss.py                                +76   -28
openfold/utils/rigid_utils.py                         +11   -2
scripts/generate_mmcif_cache.py                       +0    -66
scripts/precompute_alignments.py                      +169  -38
openfold/config.py:

@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
 chunk_size = mlc.FieldReference(4, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
+tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
 templates_enabled = mlc.FieldReference(True, field_type=bool)
 embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)

@@ -228,7 +229,7 @@ config = mlc.ConfigDict(
             "use_small_bfd": False,
             "data_loaders": {
                 "batch_size": 1,
-                "num_workers": 8,
+                "num_workers": 16,
             },
         },
     },

@@ -320,10 +321,10 @@ config = mlc.ConfigDict(
                 "transition_n": 4,
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
-                "blocks_per_ckpt": blocks_per_ckpt,
                 "clear_cache_between_blocks": True,
                 "inf": 1e9,
                 "eps": eps,  # 1e-10,
+                "ckpt": blocks_per_ckpt is not None,
             },
             "enabled": True,
         },

@@ -376,7 +377,7 @@ config = mlc.ConfigDict(
             "tm": {
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
-                "enabled": False,
+                "enabled": tm_enabled,
             },
             "masked_msa": {
                 "c_m": c_m,

@@ -454,6 +455,7 @@ config = mlc.ConfigDict(
                 "max_resolution": 3.0,
                 "eps": eps,  # 1e-8,
                 "weight": 0.0,
+                "enabled": tm_enabled,
             },
             "eps": eps,
         },
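The new `tm_enabled` flag is an `mlc.FieldReference`, so one value now drives both the TM head config and the TM loss config above. A minimal sketch of how such a shared reference behaves, assuming only the `ml_collections` package (illustrative, not OpenFold code):

import ml_collections as mlc

tm_enabled = mlc.FieldReference(False, field_type=bool)
cfg = mlc.ConfigDict({
    "heads": {"tm": {"enabled": tm_enabled}},
    "loss": {"tm": {"enabled": tm_enabled}},
})

# Setting the shared reference flips both consumers at once.
tm_enabled.set(True)
assert cfg.heads.tm.enabled and cfg.loss.tm.enabled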
openfold/data/data_modules.py:

@@ -4,7 +4,7 @@ import json
 import logging
 import os
 import pickle
-from typing import Optional, Sequence
+from typing import Optional, Sequence, List, Any

 import ml_collections as mlc
 import numpy as np

@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         max_template_date: str,
         config: mlc.ConfigDict,
         kalign_binary_path: str = '/usr/bin/kalign',
-        mapping_path: Optional[str] = None,
         max_template_hits: int = 4,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         shuffle_top_k_prefiltered: Optional[int] = None,
         treat_pdb_as_distillation: bool = True,
+        mapping_path: Optional[str] = None,
         mode: str = "train",
         _output_raw: bool = False,
+        _alignment_index: Optional[Any] = None
     ):
         """
         Args:

@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 A dataset config object. See openfold.config
             kalign_binary_path:
                 Path to kalign binary.
-            mapping_path:
-                A json file containing a mapping from consecutive numerical
-                ids to sample names (matching the directories in data_dir).
-                Samples not in this mapping are ignored. Can be used to
-                implement the various training-time filters described in
-                the AlphaFold supplement.
             max_template_hits:
                 An upper bound on how many templates are considered. During
                 training, the templates ultimately used are subsampled

@@ -89,26 +84,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.treat_pdb_as_distillation = treat_pdb_as_distillation
         self.mode = mode
         self._output_raw = _output_raw
+        self._alignment_index = _alignment_index

         valid_modes = ["train", "eval", "predict"]
         if(mode not in valid_modes):
             raise ValueError(f'mode must be one of {valid_modes}')

-        if(mapping_path is None):
-            self.mapping = {
-                str(i): os.path.splitext(name)[0]
-                for i, name in enumerate(os.listdir(alignment_dir))
-            }
-        else:
-            with open(mapping_path, 'r') as fp:
-                self.mapping = json.load(fp)
-
         if(template_release_dates_cache_path is None):
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

+        if(_alignment_index is not None):
+            self._chain_ids = list(_alignment_index.keys())
+        elif(mapping_path is None):
+            self._chain_ids = list(os.listdir(alignment_dir))
+        else:
+            with open(mapping_path, "r") as f:
+                self._chain_ids = [l.strip() for l in f.readlines()]
+
+        self._chain_id_to_idx_dict = {
+            chain: i for i, chain in enumerate(self._chain_ids)
+        }
+
         template_featurizer = templates.TemplateHitFeaturizer(
             mmcif_dir=template_mmcif_dir,
             max_template_date=max_template_date,

@@ -126,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         if(not self._output_raw):
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
+    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
         with open(path, 'r') as f:
             mmcif_string = f.read()

@@ -145,14 +144,26 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             mmcif=mmcif_object,
             alignment_dir=alignment_dir,
             chain_id=chain_id,
+            _alignment_index=_alignment_index
         )
         return data

+    def chain_id_to_idx(self, chain_id):
+        return self._chain_id_to_idx_dict[chain_id]
+
+    def idx_to_chain_id(self, idx):
+        return self._chain_ids[idx]
+
     def __getitem__(self, idx):
-        name = self.mapping[str(idx)]
+        name = self.idx_to_chain_id(idx)
         alignment_dir = os.path.join(self.alignment_dir, name)

+        _alignment_index = None
+        if(self._alignment_index is not None):
+            alignment_dir = self.alignment_dir
+            _alignment_index = self._alignment_index[name]
+
         if(self.mode == 'train' or self.mode == 'eval'):
             spl = name.rsplit('_', 1)
             if(len(spl) == 2):

@@ -164,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             path = os.path.join(self.data_dir, file_id)
             if(os.path.exists(path + ".cif")):
                 data = self._parse_mmcif(
-                    path + ".cif", file_id, chain_id, alignment_dir
+                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".core")):
                 data = self.data_pipeline.process_core(
-                    path + ".core", alignment_dir
+                    path + ".core", alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".pdb")):
                 data = self.data_pipeline.process_pdb(

@@ -176,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     alignment_dir=alignment_dir,
                     is_distillation=self.treat_pdb_as_distillation,
                     chain_id=chain_id,
+                    _alignment_index=_alignment_index,
                 )
             else:
                 raise ValueError("Invalid file type")

@@ -184,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             data = self.data_pipeline.process_fasta(
                 fasta_path=path,
                 alignment_dir=alignment_dir,
+                _alignment_index=_alignment_index,
             )
         if(self._output_raw):

@@ -196,53 +209,150 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         return feats

     def __len__(self):
-        return len(self.mapping.keys())
+        return len(self._chain_ids)


-def looped_sequence(sequence):
-    while True:
-        for x in sequence:
-            yield x
+def deterministic_train_filter(
+    prot_data_cache_entry: Any,
+    max_resolution: float = 9.,
+    max_single_aa_prop: float = 0.8,
+) -> bool:
+    # Hard filters
+    resolution = prot_data_cache_entry.get("resolution", None)
+    if(resolution is not None and resolution > max_resolution):
+        return False
+
+    seq = prot_data_cache_entry["seq"]
+    counts = {}
+    for aa in seq:
+        counts.setdefault(aa, 0)
+        counts[aa] += 1
+    largest_aa_count = max(counts.values())
+    largest_single_aa_prop = largest_aa_count / len(seq)
+    if(largest_single_aa_prop > max_single_aa_prop):
+        return False
+
+    return True
+
+
+def get_stochastic_train_filter_prob(
+    prot_data_cache_entry: Any,
+) -> List[float]:
+    # Stochastic filters
+    probabilities = []
+
+    cluster_size = prot_data_cache_entry.get("cluster_size", None)
+    if(cluster_size is not None and cluster_size > 0):
+        probabilities.append(1 / cluster_size)
+
+    chain_length = len(prot_data_cache_entry["seq"])
+    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
+
+    # Risk of underflow here?
+    out = 1
+    for p in probabilities:
+        out *= p
+
+    return out


-class OpenFoldDataset(torch.utils.data.IterableDataset):
+class OpenFoldDataset(torch.utils.data.Dataset):
     """
-        The Dataset is written to accommodate the requirement that proteins are
-        sampled from the distillation set with some probability p
-        and from the PDB set with probability (1 - p). Proteins are sampled
-        from both sets without replacement, and as soon as either set is
-        emptied, it is refilled. The Dataset therefore has an arbitrary length.
-        Nevertheless, for compatibility with various PyTorch Lightning
-        functionalities, it is possible to specify an epoch length. This length
-        has no effect on the output of the Dataset.
+        Implements the stochastic filters applied during AlphaFold's training.
+        Because samples are selected from constituent datasets randomly, the
+        length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
+        and filtered once at initialization.
     """
     def __init__(self,
         datasets: Sequence[OpenFoldSingleDataset],
         probabilities: Sequence[int],
         epoch_len: int,
+        prot_data_cache_paths: List[str],
+        generator: torch.Generator = None,
+        _roll_at_init: bool = True,
     ):
         self.datasets = datasets
-        self.samplers = [
-            looped_sequence(RandomSampler(d)) for d in datasets
-        ]
+        self.probabilities = probabilities
         self.epoch_len = epoch_len
-        self.distr = torch.distributions.categorical.Categorical(
-            probs=torch.tensor(probabilities),
-        )
+        self.generator = generator
+
+        self.prot_data_caches = []
+        for path in prot_data_cache_paths:
+            with open(path, "r") as fp:
+                self.prot_data_caches.append(json.load(fp))
+
+        def looped_shuffled_dataset_idx(dataset_len):
+            while True:
+                # Uniformly shuffle each dataset's indices
+                weights = [1. for _ in range(dataset_len)]
+                shuf = torch.multinomial(
+                    torch.tensor(weights),
+                    num_samples=dataset_len,
+                    replacement=False,
+                    generator=self.generator,
+                )
+                for idx in shuf:
+                    yield idx
+
+        def looped_samples(dataset_idx):
+            max_cache_len = int(epoch_len * probabilities[dataset_idx])
+            dataset = self.datasets[dataset_idx]
+            idx_iter = looped_shuffled_dataset_idx(len(dataset))
+            prot_data_cache = self.prot_data_caches[dataset_idx]
+            while True:
+                weights = []
+                idx = []
+                for _ in range(max_cache_len):
+                    candidate_idx = next(idx_iter)
+                    chain_id = dataset.idx_to_chain_id(candidate_idx)
+                    prot_data_cache_entry = prot_data_cache[chain_id]
+                    if(not deterministic_train_filter(prot_data_cache_entry)):
+                        continue
+                    p = get_stochastic_train_filter_prob(
+                        prot_data_cache_entry,
+                    )
+                    weights.append([1. - p, p])
+                    idx.append(candidate_idx)
+
+                samples = torch.multinomial(
+                    torch.tensor(weights),
+                    num_samples=1,
+                    generator=self.generator,
+                )
+                samples = samples.squeeze()
+
+                cache = [i for i, s in zip(idx, samples) if s]
+
+                for datapoint_idx in cache:
+                    yield datapoint_idx
+
+        self._samples = [looped_samples(i) for i in range(len(self.datasets))]
+
+        if(_roll_at_init):
+            self.reroll()

-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        dataset_idx = self.distr.sample()
-        sampler = self.samplers[dataset_idx]
-        element_idx = next(sampler)
-        return self.datasets[dataset_idx][element_idx]
+    def __getitem__(self, idx):
+        dataset_idx, datapoint_idx = self.datapoints[idx]
+        return self.datasets[dataset_idx][datapoint_idx]

     def __len__(self):
         return self.epoch_len

+    def reroll(self):
+        dataset_choices = torch.multinomial(
+            torch.tensor(self.probabilities),
+            num_samples=self.epoch_len,
+            replacement=True,
+            generator=self.generator,
+        )
+        self.datapoints = []
+        for dataset_idx in dataset_choices:
+            samples = self._samples[dataset_idx]
+            datapoint_idx = next(samples)
+            self.datapoints.append((dataset_idx, datapoint_idx))


 class OpenFoldBatchCollator:
     def __init__(self, config, stage="train"):
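The stochastic filter introduced above mirrors the AlphaFold supplement's training-set sampling: a chain's acceptance probability combines an inverse cluster-size term with a length term clamped to the range [256, 512] and divided by 512. A standalone arithmetic check of what get_stochastic_train_filter_prob computes (illustrative values, not OpenFold code):

def train_filter_prob(cluster_size, chain_length):
    # Mirrors the probability product in get_stochastic_train_filter_prob.
    p = 1.0
    if cluster_size is not None and cluster_size > 0:
        p *= 1 / cluster_size
    p *= (1 / 512) * max(min(chain_length, 512), 256)
    return p

print(train_filter_prob(40, 300))  # (1/40) * (300/512) ~= 0.0146
print(train_filter_prob(1, 100))   # length clamped up to 256 -> 0.5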
@@ -283,7 +393,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         keyed_probs.append(
             ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
         )

         if(stage_cfg.uniform_recycling):
             recycling_probs = [
                 1. / (max_iters + 1) for _ in range(max_iters + 1)

@@ -293,7 +403,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
                 0. for _ in range(max_iters + 1)
             ]
             recycling_probs[-1] = 1.

         keyed_probs.append(
             ("no_recycling_iters", recycling_probs)
         )
@@ -361,8 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
+        train_prot_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
+        distillation_prot_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,

@@ -373,6 +485,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         batch_seed: Optional[int] = None,
+        train_epoch_len: int = 50000,
+        _alignment_index_path: Optional[str] = None,
         **kwargs
     ):
         super(OpenFoldDataModule, self).__init__()

@@ -382,8 +496,12 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
+        self.train_prot_data_cache_path = train_prot_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
+        self.distillation_prot_data_cache_path = (
+            distillation_prot_data_cache_path
+        )
         self.val_data_dir = val_data_dir
         self.val_alignment_dir = val_alignment_dir
         self.predict_data_dir = predict_data_dir

@@ -396,6 +514,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         )
         self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
         self.batch_seed = batch_seed
+        self.train_epoch_len = train_epoch_len

         if(self.train_data_dir is None and self.predict_data_dir is None):
             raise ValueError(

@@ -405,11 +524,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.training_mode = self.train_data_dir is not None
-        if(self.training_mode and self.train_alignment_dir is None):
+        if(self.training_mode and train_alignment_dir is None):
             raise ValueError(
                 'In training mode, train_alignment_dir must be specified'
             )
-        elif(not self.training_mode and self.predict_alignment_dir is None):
+        elif(not self.training_mode and predict_alignment_dir is None):
             raise ValueError(
                 'In inference mode, predict_alignment_dir must be specified'
             )

@@ -419,10 +538,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 'be specified as well'
             )

-    def setup(self, stage: Optional[str] = None):
-        if(stage is None):
-            stage = "train"
+        # An ad-hoc measure for our particular filesystem restrictions
+        self._alignment_index = None
+        if(_alignment_index_path is not None):
+            with open(_alignment_index_path, "r") as fp:
+                self._alignment_index = json.load(fp)

+    def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,

@@ -435,8 +557,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
             self.obsolete_pdbs_file_path,
         )

         if(self.training_mode):
-            self.train_dataset = dataset_gen(
+            train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 alignment_dir=self.train_alignment_dir,
                 mapping_path=self.train_mapping_path,

@@ -446,8 +568,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 treat_pdb_as_distillation=False,
                 mode="train",
                 _output_raw=True,
+                _alignment_index=self._alignment_index,
             )

+            distillation_dataset = None
             if(self.distillation_data_dir is not None):
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,

@@ -460,13 +584,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 )

-            d_prob = self.config.train.distillation_prob
-            self.train_dataset = OpenFoldDataset(
-                datasets=[self.train_dataset, distillation_dataset],
-                probabilities=[1 - d_prob, d_prob],
-                epoch_len=(
-                    self.train_dataset.len() + distillation_dataset.len()
-                ),
-            )
+            if(distillation_dataset is not None):
+                d_prob = self.config.train.distillation_prob
+                datasets = [train_dataset, distillation_dataset]
+                probabilities = [1 - d_prob, d_prob]
+                prot_data_cache_paths = [
+                    self.train_prot_data_cache_path,
+                    self.distillation_prot_data_cache_path,
+                ]
+            else:
+                datasets = [train_dataset]
+                probabilities = [1.]
+                prot_data_cache_paths = [
+                    self.train_prot_data_cache_path,
+                ]
+
+            self.train_dataset = OpenFoldDataset(
+                datasets=datasets,
+                probabilities=probabilities,
+                epoch_len=self.train_epoch_len,
+                prot_data_cache_paths=prot_data_cache_paths,
+                _roll_at_init=False,
+            )

         if(self.val_data_dir is not None):
             self.eval_dataset = dataset_gen(

@@ -496,6 +636,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
         dataset = None
         if(stage == "train"):
             dataset = self.train_dataset
+
+            # Filter the dataset, if necessary
+            dataset.reroll()
         elif(stage == "eval"):
             dataset = self.eval_dataset
         elif(stage == "predict"):
openfold/data/data_pipeline.py:

@@ -422,42 +422,89 @@ class DataPipeline:
     def _parse_msa_data(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}
-        for f in os.listdir(alignment_dir):
-            path = os.path.join(alignment_dir, f)
-            ext = os.path.splitext(f)[-1]
-
-            if(ext == ".a3m"):
-                with open(path, "r") as fp:
-                    msa, deletion_matrix = parsers.parse_a3m(fp.read())
-                data = {"msa": msa, "deletion_matrix": deletion_matrix}
-            elif(ext == ".sto"):
-                with open(path, "r") as fp:
-                    msa, deletion_matrix, _ = parsers.parse_stockholm(
-                        fp.read()
-                    )
-                data = {"msa": msa, "deletion_matrix": deletion_matrix}
-            else:
-                continue
-
-            msa_data[f] = data
+        if(_alignment_index is not None):
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
+
+            def read_msa(start, size):
+                fp.seek(start)
+                msa = fp.read(size).decode("utf-8")
+                return msa
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".a3m"):
+                    msa, deletion_matrix = parsers.parse_a3m(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                elif(ext == ".sto"):
+                    msa, deletion_matrix, _ = parsers.parse_stockholm(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                else:
+                    continue
+
+                msa_data[name] = data
+
+            fp.close()
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
+
+                if(ext == ".a3m"):
+                    with open(path, "r") as fp:
+                        msa, deletion_matrix = parsers.parse_a3m(fp.read())
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                elif(ext == ".sto"):
+                    with open(path, "r") as fp:
+                        msa, deletion_matrix, _ = parsers.parse_stockholm(
+                            fp.read()
+                        )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                else:
+                    continue
+
+                msa_data[f] = data

         return msa_data

     def _parse_template_hits(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None
     ) -> Mapping[str, Any]:
         all_hits = {}
-        for f in os.listdir(alignment_dir):
-            path = os.path.join(alignment_dir, f)
-            ext = os.path.splitext(f)[-1]
-
-            if(ext == ".hhr"):
-                with open(path, "r") as fp:
-                    hits = parsers.parse_hhr(fp.read())
-                all_hits[f] = hits
+        if(_alignment_index is not None):
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
+
+            def read_template(start, size):
+                fp.seek(start)
+                return fp.read(size).decode("utf-8")
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".hhr"):
+                    hits = parsers.parse_hhr(read_template(start, size))
+                    all_hits[name] = hits
+
+            fp.close()
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
+
+                if(ext == ".hhr"):
+                    with open(path, "r") as fp:
+                        hits = parsers.parse_hhr(fp.read())
+                    all_hits[f] = hits

         return all_hits
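The _alignment_index consulted above maps a chain to a single packed "db" file plus (name, start, size) byte ranges, so many small alignment files can be served from one file handle. A hypothetical builder for that layout — build_alignment_index is not part of this commit; the on-disk shape is inferred from the seek/read pattern above:

import os

def build_alignment_index(alignment_dir, db_path):
    # Pack every alignment file into one db, recording byte offsets in the
    # {"db": ..., "files": [(name, start, size), ...]} shape read above.
    index = {"db": os.path.basename(db_path), "files": []}
    start = 0
    with open(db_path, "wb") as db:
        for name in sorted(os.listdir(alignment_dir)):
            with open(os.path.join(alignment_dir, name), "rb") as f:
                data = f.read()
            db.write(data)
            index["files"].append((name, start, len(data)))
            start += len(data)
    return index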
@@ -465,8 +512,9 @@ class DataPipeline:
         self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,
+        _alignment_index: Optional[str] = None
     ) -> Mapping[str, Any]:
-        msa_data = self._parse_msa_data(alignment_dir)
+        msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
         if(len(msa_data) == 0):
             if(input_sequence is None):

@@ -496,6 +544,7 @@ class DataPipeline:
         self,
         fasta_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """Assembles features for a single sequence in a FASTA file"""
         with open(fasta_path) as f:

@@ -509,7 +558,7 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,

@@ -522,7 +571,7 @@ class DataPipeline:
             num_res=num_res,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {
             **sequence_features,

@@ -535,6 +584,7 @@ class DataPipeline:
         mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
         alignment_dir: str,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a specific chain in an mmCIF object.

@@ -552,7 +602,7 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)

         input_sequence = mmcif.chain_to_seqres[chain_id]
-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,

@@ -560,7 +610,7 @@ class DataPipeline:
             query_release_date=to_date(mmcif.header["release_date"])
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**mmcif_feats, **template_features, **msa_features}

@@ -570,6 +620,7 @@ class DataPipeline:
         alignment_dir: str,
         is_distillation: bool = True,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a protein in a PDB file.

@@ -586,14 +637,14 @@ class DataPipeline:
             is_distillation
         )

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
             self.template_featurizer,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**pdb_feats, **template_features, **msa_features}

@@ -601,6 +652,7 @@ class DataPipeline:
         self,
         core_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a protein in a ProteinNet .core file.

@@ -613,7 +665,7 @@ class DataPipeline:
         description = os.path.splitext(os.path.basename(core_path))[0].upper()
         core_feats = make_protein_features(protein_object, description)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
openfold/model/embedders.py:

@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 from typing import Tuple

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import one_hot

@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module):
         self.bins = None
         self.linear = Linear(self.no_bins, self.c_z)
-        self.layer_norm_m = nn.LayerNorm(self.c_m)
-        self.layer_norm_z = nn.LayerNorm(self.c_z)
+        self.layer_norm_m = LayerNorm(self.c_m)
+        self.layer_norm_z = LayerNorm(self.c_z)

     def forward(
         self,
openfold/model/evoformer.py:

@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import torch
 import torch.nn as nn
 from typing import Tuple, Optional
 from functools import partial

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
 from openfold.model.msa import (
     MSARowAttentionWithPairBias,

@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationOutgoing,
     TriangleMultiplicationIncoming,
 )
-from openfold.utils.checkpointing import checkpoint_blocks
+from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
 from openfold.utils.tensor_utils import chunk_layer

@@ -60,7 +61,7 @@ class MSATransition(nn.Module):
         self.c_m = c_m
         self.n = n

-        self.layer_norm = nn.LayerNorm(self.c_m)
+        self.layer_norm = LayerNorm(self.c_m)
         self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
         return m


-class EvoformerBlock(nn.Module):
+class EvoformerBlockCore(nn.Module):
     def __init__(
         self,
         c_m: int,
         c_z: int,
-        c_hidden_msa_att: int,
         c_hidden_opm: int,
         c_hidden_mul: int,
         c_hidden_pair_att: int,
         no_heads_msa: int,
         no_heads_pair: int,
         transition_n: int,
-        msa_dropout: float,
         pair_dropout: float,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
     ):
-        super(EvoformerBlock, self).__init__()
+        super(EvoformerBlockCore, self).__init__()

-        self._is_extra_msa_stack = _is_extra_msa_stack
-
-        self.msa_att_row = MSARowAttentionWithPairBias(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-        )
-
-        if _is_extra_msa_stack:
-            self.msa_att_col = MSAColumnGlobalAttention(
-                c_in=c_m,
-                c_hidden=c_hidden_msa_att,
-                no_heads=no_heads_msa,
-                inf=inf,
-                eps=eps,
-            )
-        else:
-            self.msa_att_col = MSAColumnAttention(
-                c_m,
-                c_hidden_msa_att,
-                no_heads_msa,
-                inf=inf,
-            )
-
         self.msa_transition = MSATransition(
             c_m=c_m,

@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
             transition_n,
         )

-        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
         self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

@@ -213,17 +185,13 @@ class EvoformerBlock(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
-        )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
         m = m + self.msa_transition(
             m, mask=msa_trans_mask, chunk_size=chunk_size
         )

@@ -245,6 +213,175 @@ class EvoformerBlock(nn.Module):
         return m, z


+class EvoformerBlock(nn.Module):
+    def __init__(
+        self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+    ):
+        super(EvoformerBlock, self).__init__()
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnAttention(
+            c_m,
+            c_hidden_msa_att,
+            no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(
+        self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+        )
+        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m, z = self.core(
+            m,
+            z,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        return m, z
+
+
+class ExtraMSABlock(nn.Module):
+    """
+        Almost identical to the standard EvoformerBlock, except in that the
+        ExtraMSABlock uses GlobalAttention for MSA column attention and
+        requires more fine-grained control over checkpointing. Separated from
+        its twin to preserve the TorchScript-ability of the latter.
+    """
+    def __init__(
+        self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+        ckpt: bool,
+    ):
+        super(ExtraMSABlock, self).__init__()
+
+        self.ckpt = ckpt
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnGlobalAttention(
+            c_in=c_m,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+            eps=eps,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(
+        self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _chunk_logits: Optional[int] = 1024,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(
+                m.clone(),
+                z=z.clone(),
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                _chunk_logits=_chunk_logits,
+                _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
+            )
+        )
+
+        def fn(m, z):
+            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m, z = self.core(
+                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+            )
+            return m, z
+
+        if(torch.is_grad_enabled() and self.ckpt):
+            checkpoint_fn = get_checkpoint_fn()
+            m, z = checkpoint_fn(fn, m, z)
+        else:
+            m, z = fn(m, z)
+
+        return m, z
+
+
 class EvoformerStack(nn.Module):
     """
     Main Evoformer trunk.

@@ -271,7 +408,6 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
-        _is_extra_msa_stack: bool = False,
         **kwargs,
     ):
         """

@@ -313,7 +449,6 @@ class EvoformerStack(nn.Module):
         self.blocks_per_ckpt = blocks_per_ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self._is_extra_msa_stack = _is_extra_msa_stack

         self.blocks = nn.ModuleList()

@@ -332,15 +467,12 @@ class EvoformerStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                _is_extra_msa_stack=_is_extra_msa_stack,
             )
             self.blocks.append(block)

-        if not self._is_extra_msa_stack:
-            self.linear = Linear(c_m, c_s)
+        self.linear = Linear(c_m, c_s)

     def forward(
         self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,

@@ -390,13 +522,8 @@ class EvoformerStack(nn.Module):
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )

-        s = None
-        if not self._is_extra_msa_stack:
-            seq_dim = -3
-            index = torch.tensor([0], device=m.device)
-            s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
-            s = s.squeeze(seq_dim)
+        s = self.linear(m[..., 0, :, :])

         return m, z, s

@@ -405,8 +532,7 @@ class ExtraMSAStack(nn.Module):
     Implements Algorithm 18.
     """
     def __init__(
         self,
         c_m: int,
         c_z: int,
         c_hidden_msa_att: int,

@@ -419,38 +545,38 @@ class ExtraMSAStack(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        blocks_per_ckpt: int,
         inf: float,
         eps: float,
+        ckpt: bool,
         clear_cache_between_blocks: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()

-        c_s = None
-        self.stack = EvoformerStack(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden_msa_att=c_hidden_msa_att,
-            c_hidden_opm=c_hidden_opm,
-            c_hidden_mul=c_hidden_mul,
-            c_hidden_pair_att=c_hidden_pair_att,
-            c_s=c_s,
-            no_heads_msa=no_heads_msa,
-            no_heads_pair=no_heads_pair,
-            no_blocks=no_blocks,
-            transition_n=transition_n,
-            msa_dropout=msa_dropout,
-            pair_dropout=pair_dropout,
-            blocks_per_ckpt=blocks_per_ckpt,
-            inf=inf,
-            eps=eps,
-            clear_cache_between_blocks=clear_cache_between_blocks,
-            _is_extra_msa_stack=True,
-        )
+        self.clear_cache_between_blocks = clear_cache_between_blocks
+
+        self.blocks = nn.ModuleList()
+        for _ in range(no_blocks):
+            block = ExtraMSABlock(
+                c_m=c_m,
+                c_z=c_z,
+                c_hidden_msa_att=c_hidden_msa_att,
+                c_hidden_opm=c_hidden_opm,
+                c_hidden_mul=c_hidden_mul,
+                c_hidden_pair_att=c_hidden_pair_att,
+                no_heads_msa=no_heads_msa,
+                no_heads_pair=no_heads_pair,
+                transition_n=transition_n,
+                msa_dropout=msa_dropout,
+                pair_dropout=pair_dropout,
+                inf=inf,
+                eps=eps,
+                ckpt=False,
+            )
+            self.blocks.append(block)

     def forward(
         self,
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,

@@ -470,13 +596,28 @@ class ExtraMSAStack(nn.Module):
                 Optional [*, N_res, N_res] pair mask
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        _, z, _ = self.stack(
-            m,
-            z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
-        )
+        #checkpoint_fn = get_checkpoint_fn()
+        #blocks = [
+        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
+        #]
+
+        #def dodo(b, *args):
+        #    torch.cuda.empty_cache()
+        #    return b(*args)
+
+        #blocks = [partial(dodo, b) for b in blocks]
+
+        #for b in blocks:
+        #    if(torch.is_grad_enabled()):
+        #        m, z = checkpoint_fn(b, *(m, z))
+        #    else:
+        #        m, z = b(m, z)
+
+        for b in self.blocks:
+            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+            if(self.clear_cache_between_blocks):
+                torch.cuda.empty_cache()

         return z
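ExtraMSABlock.forward above wraps everything after row attention in a closure and checkpoints it, so those activations are recomputed during the backward pass rather than stored. A minimal, self-contained sketch of that pattern using torch.utils.checkpoint (toy shapes, not OpenFold code):

import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(8, 8)

def fn(m, z):
    # Stand-in for the column-attention + core computation.
    return lin(m) + z, z

m = torch.randn(4, 8, requires_grad=True)
z = torch.randn(4, 8, requires_grad=True)

if torch.is_grad_enabled():
    # Only the inputs are saved; intermediates are recomputed on backward.
    m, z = checkpoint(fn, m, z)
else:
    m, z = fn(m, z)

(m.sum() + z.sum()).backward()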
openfold/model/heads.py:

@@ -16,7 +16,7 @@
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.loss import (
     compute_plddt,
     compute_tm,

@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module):
         self.c_in = c_in
         self.c_hidden = c_hidden

-        self.layer_norm = nn.LayerNorm(self.c_in)
+        self.layer_norm = LayerNorm(self.c_in)
         self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
         self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
openfold/model/model.py:

@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
                 inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram,
-            )
+            ).to(z.dtype)
             t = self.template_pair_embedder(t)
             single_template_embeds.update({"pair": t})

@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
         # [*, S_t, N, N, C_z]
         t = self.template_pair_stack(
             template_embeds["pair"],
-            pair_mask.unsqueeze(-3),
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
             _mask_trans=self.config._mask_trans,
         )

@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
         t = self.template_pointwise_att(
             t,
             z,
-            template_mask=batch["template_mask"],
+            template_mask=batch["template_mask"].to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
         )
         t = t * (torch.sum(batch["template_mask"]) > 0)

@@ -175,6 +175,12 @@ class AlphaFold(nn.Module):
         # Primary output dictionary
         outputs = {}

+        # This needs to be done manually for DeepSpeed's sake
+        dtype = next(self.parameters()).dtype
+        for k in feats:
+            if(feats[k].dtype == torch.float32):
+                feats[k] = feats[k].to(dtype=dtype)
+
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
         no_batch_dims = len(batch_dims)
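The cast added above keeps float32 input features in step with the module's parameter dtype when DeepSpeed runs the model in reduced precision. A standalone toy version of the same idiom (bfloat16 is chosen here only because it runs on CPU; all names are illustrative):

import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)
feats = {
    "target_feat": torch.randn(2, 4),   # float32: will be cast
    "residue_index": torch.arange(2),   # int64: left untouched
}

dtype = next(model.parameters()).dtype
for k in feats:
    if feats[k].dtype == torch.float32:
        feats[k] = feats[k].to(dtype=dtype)

out = model(feats["target_feat"])  # dtypes now agree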
...
@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
...
@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
requires_grad
=
False
,
requires_grad
=
False
,
)
)
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
)
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
# m_1_prev_emb: [*, N, C_m]
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
# z_prev_emb: [*, N, N, C_z]
...
@@ -246,34 +254,32 @@ class AlphaFold(nn.Module):
        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
-           template_mask = feats["template_mask"]
-           if (torch.any(template_mask)):
-               template_feats = {
-                   k: v for k, v in feats.items() if k.startswith("template_")
-               }
-               template_embeds = self.embed_templates(
-                   template_feats,
-                   z,
-                   pair_mask,
-                   no_batch_dims,
-               )
+           template_feats = {
+               k: v for k, v in feats.items() if k.startswith("template_")
+           }
+           template_embeds = self.embed_templates(
+               template_feats,
+               z,
+               pair_mask.to(dtype=z.dtype),
+               no_batch_dims,
+           )

            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]],
                    dim=-3
                )

                # [*, S, N]
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
                    dim=-2
                )

        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
...
@@ -284,9 +290,9 @@ class AlphaFold(nn.Module):
            z = self.extra_msa_stack(
                a,
                z,
-               msa_mask=feats["extra_msa_mask"],
+               msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                chunk_size=self.globals.chunk_size,
-               pair_mask=pair_mask,
+               pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )
...
@@ -297,8 +303,8 @@ class AlphaFold(nn.Module):
        m, z, s = self.evoformer(
            m,
            z,
-           msa_mask=msa_mask,
-           pair_mask=pair_mask,
+           msa_mask=msa_mask.to(dtype=m.dtype),
+           pair_mask=pair_mask.to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )
...
@@ -312,7 +318,7 @@ class AlphaFold(nn.Module):
            s,
            z,
            feats["aatype"],
-           mask=feats["seq_mask"],
+           mask=feats["seq_mask"].to(dtype=s.dtype),
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
@@ -336,7 +342,9 @@ class AlphaFold(nn.Module):
...
@@ -336,7 +342,9 @@ class AlphaFold(nn.Module):
def
_disable_activation_checkpointing
(
self
):
def
_disable_activation_checkpointing
(
self
):
self
.
template_pair_stack
.
blocks_per_ckpt
=
None
self
.
template_pair_stack
.
blocks_per_ckpt
=
None
self
.
evoformer
.
blocks_per_ckpt
=
None
self
.
evoformer
.
blocks_per_ckpt
=
None
self
.
extra_msa_stack
.
stack
.
blocks_per_ckpt
=
None
for
b
in
self
.
extra_msa_stack
.
blocks
:
b
.
ckpt
=
False
def
_enable_activation_checkpointing
(
self
):
def
_enable_activation_checkpointing
(
self
):
self
.
template_pair_stack
.
blocks_per_ckpt
=
(
self
.
template_pair_stack
.
blocks_per_ckpt
=
(
...
@@ -345,9 +353,9 @@ class AlphaFold(nn.Module):
...
@@ -345,9 +353,9 @@ class AlphaFold(nn.Module):
self
.
evoformer
.
blocks_per_ckpt
=
(
self
.
evoformer
.
blocks_per_ckpt
=
(
self
.
config
.
evoformer_stack
.
blocks_per_ckpt
self
.
config
.
evoformer_stack
.
blocks_per_ckpt
)
)
self
.
extra_msa_stack
.
stack
.
blocks_per_ckpt
=
(
self
.
config
.
extra_msa
.
extra_msa_stack
.
blocks
_per_ckpt
for
b
in
self
.
extra_msa_stack
.
blocks
:
)
b
.
ckpt
=
self
.
config
.
extra_msa
.
extra_msa_stack
.
ckpt
def
forward
(
self
,
batch
):
def
forward
(
self
,
batch
):
"""
"""
...
...
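Note the asymmetry this introduces: the Evoformer and template stacks still checkpoint groups of whole blocks through blocks_per_ckpt, while each extra-MSA block now carries an individual ckpt flag that is flipped on for training and off for validation/inference. A hedged sketch of that per-block toggle on a generic module list (all names here are illustrative):

    import torch.nn as nn

    class Stack(nn.Module):
        def __init__(self, no_blocks: int):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(no_blocks))
            for b in self.blocks:
                b.ckpt = False  # per-block flag, as in the extra-MSA stack

    def set_ckpt(stack: Stack, ckpt: bool):
        # Mirrors _enable/_disable_activation_checkpointing: flip every
        # block's flag instead of setting one stack-wide value.
        for b in stack.blocks:
            b.ckpt = ckpt

    stack = Stack(4)
    set_ckpt(stack, True)   # e.g. while training
    set_ckpt(stack, False)  # e.g. during evaluation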
openfold/model/msa.py
...
@@ -16,9 +16,16 @@
 import math

 import torch
 import torch.nn as nn
-from typing import Optional, List
+from typing import Optional, List, Tuple

-from openfold.model.primitives import Linear, Attention, GlobalAttention
+from openfold.model.primitives import (
+    Linear,
+    LayerNorm,
+    Attention,
+    GlobalAttention,
+    _attention_chunked_trainable,
+)
+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
...
@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
self
.
c_z
=
c_z
self
.
c_z
=
c_z
self
.
inf
=
inf
self
.
inf
=
inf
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_in
)
self
.
layer_norm_m
=
LayerNorm
(
self
.
c_in
)
self
.
layer_norm_z
=
None
self
.
layer_norm_z
=
None
self
.
linear_z
=
None
self
.
linear_z
=
None
if
self
.
pair_bias
:
if
self
.
pair_bias
:
self
.
layer_norm_z
=
nn
.
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_z
=
LayerNorm
(
self
.
c_z
)
self
.
linear_z
=
Linear
(
self
.
linear_z
=
Linear
(
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
)
self
.
mha
=
Attention
(
self
.
mha
=
Attention
(
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
)
)
...
@@ -83,32 +90,16 @@ class MSAAttention(nn.Module):
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
-           {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+           {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

-   def forward(self,
+   def _prep_inputs(self,
        m: torch.Tensor,
-       z: Optional[torch.Tensor] = None,
-       mask: Optional[torch.Tensor] = None,
-       chunk_size: Optional[int] = None,
-   ) -> torch.Tensor:
-       """
-       Args:
-           m:
-               [*, N_seq, N_res, C_m] MSA embedding
-           z:
-               [*, N_res, N_res, C_z] pair embedding. Required only if
-               pair_bias is True
-           mask:
-               [*, N_seq, N_res] MSA mask
-           chunk_size:
-               Size of chunks into which the inputs are split along their
-               batch dimensions. A low value decreases memory overhead at the
-               cost of slower execution. Chunking is not performed by default.
-       """
+       z: Optional[torch.Tensor],
+       mask: Optional[torch.Tensor],
+   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)
...
@@ -120,16 +111,14 @@ class MSAAttention(nn.Module):
        )

        # [*, N_seq, 1, 1, N_res]
-       bias = (self.inf * (mask - 1))[..., :, None, None, :]
+       mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
-       bias = bias.expand(
-           ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-       )
-       biases = [bias]
+       #bias = bias.expand(
+       #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
+       #)

        if (self.pair_bias and
            z is not None and                      # For the
...
@@ -138,19 +127,98 @@ class MSAAttention(nn.Module):
        ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

-           biases.append(z)
+       return m, mask_bias, z
+
+   @torch.jit.ignore
+   def _chunked_msa_attn(self,
+       m: torch.Tensor,
+       z: Optional[torch.Tensor],
+       mask: Optional[torch.Tensor],
+       chunk_logits: int,
+       checkpoint: bool,
+   ) -> torch.Tensor:
+       MSA_DIM = -4
+
+       def _get_qkv(m, z):
+           m, mask_bias, z = self._prep_inputs(m, z, mask)
+           q, k, v = self.mha._prep_qkv(m, m)
+           return m, q, k, v, mask_bias, z
+
+       checkpoint_fn = get_checkpoint_fn()
+
+       if (torch.is_grad_enabled() and checkpoint):
+           m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+       else:
+           m, q, k, v, mask_bias, z = _get_qkv(m, z)
+
+       o = _attention_chunked_trainable(
+           query=q,
+           key=k,
+           value=v,
+           biases=[mask_bias, z],
+           chunk_size=chunk_logits,
+           chunk_dim=MSA_DIM,
+           checkpoint=checkpoint,
+       )
+
+       if (torch.is_grad_enabled() and checkpoint):
+           # Storing an additional m here is far from ideal
+           m = checkpoint_fn(self.mha._wrap_up, o, m)
+       else:
+           m = self.mha._wrap_up(o, m)
+
+       return m
+
+   def forward(self,
+       m: torch.Tensor,
+       z: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       chunk_size: Optional[int] = None,
+       _chunk_logits: Optional[int] = None,
+       _checkpoint_chunks: Optional[bool] = None,
+   ) -> torch.Tensor:
+       """
+       Args:
+           m:
+               [*, N_seq, N_res, C_m] MSA embedding
+           z:
+               [*, N_res, N_res, C_z] pair embedding. Required only if
+               pair_bias is True
+           mask:
+               [*, N_seq, N_res] MSA mask
+           chunk_size:
+               Size of chunks into which the inputs are split along their
+               batch dimensions. A low value decreases memory overhead at the
+               cost of slower execution. Chunking is not performed by default.
+       """
+       if (_chunk_logits is not None):
+           return self._chunked_msa_attn(
+               m=m, z=z, mask=mask,
+               chunk_logits=_chunk_logits,
+               checkpoint=_checkpoint_chunks,
+           )
+
+       m, mask_bias, z = self._prep_inputs(m, z, mask)
+
+       biases = [mask_bias]
+       if (z is not None):
+           biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
-           m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)
+           m = self.mha(q_x=m, kv_x=m, biases=biases)

        return m
...
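Throughout this commit, chunk_size is threaded into attention modules and handed to chunk_layer, which splits the inputs along their batch-like dimensions, runs the layer per slice, and concatenates the results. A toy re-implementation of the idea (deliberately not OpenFold's chunk_layer) that makes the exactness of the memory/speed trade visible:

    import torch

    def chunked_apply(fn, x: torch.Tensor, chunk_size: int, dim: int = 0):
        # Split x into slices of at most chunk_size along `dim`, apply fn to
        # each slice, and stitch the results back together. Peak memory now
        # scales with the chunk rather than the full tensor.
        out = [fn(piece) for piece in torch.split(x, chunk_size, dim=dim)]
        return torch.cat(out, dim=dim)

    x = torch.randn(10, 64)
    full = torch.softmax(x, dim=-1)
    chunked = chunked_apply(lambda t: torch.softmax(t, dim=-1), x, chunk_size=4)
    assert torch.allclose(full, chunked)  # exact: rows are independent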
openfold/model/pair_transition.py
...
@@ -17,7 +17,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import chunk_layer
...
@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
        self.c_z = c_z
        self.n = n

-       self.layer_norm = nn.LayerNorm(self.c_z)
+       self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
...
openfold/model/primitives.py
...
@@ -13,14 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
+import deepspeed
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm

+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
...
@@ -164,6 +167,135 @@ class Linear(nn.Linear):
            raise ValueError("Invalid init string.")

+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if (d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+            with torch.cuda.amp.autocast(enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps,
+                )
+        else:
+            out = nn.functional.layer_norm(
+                x,
+                self.c_in,
+                self.weight,
+                self.bias,
+                self.eps,
+            )
+
+        return out
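The point of this custom LayerNorm is bfloat16: when DeepSpeed is not managing precision, it disables autocast and casts the fp32 affine parameters down to the input dtype, so the kernel runs (and returns) in bf16 instead of being silently promoted to fp32. The parameter-casting half of that in isolation, as a small sketch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 16, dtype=torch.bfloat16)
    weight = torch.ones(16)   # fp32 master copy
    bias = torch.zeros(16)

    # Casting the parameters to the activation dtype keeps the whole
    # layer_norm in bf16 rather than promoting the activations to fp32.
    out = F.layer_norm(x, (16,), weight.to(x.dtype), bias.to(x.dtype), 1e-5)
    assert out.dtype == torch.bfloat16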
+@torch.jit.ignore
+def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Softmax, but without automatic casting to fp32 when the input is of
+    type bfloat16
+    """
+    d = t.dtype
+    if (d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        with torch.cuda.amp.autocast(enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    else:
+        s = torch.nn.functional.softmax(t, dim=dim)
+
+    return s
+#@torch.jit.script
+def _attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    biases: List[torch.Tensor],
+) -> torch.Tensor:
+    # [*, H, Q, C_hidden]
+    query = permute_final_dims(query, (1, 0, 2))
+
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 2, 0))
+
+    # [*, H, V, C_hidden]
+    value = permute_final_dims(value, (1, 0, 2))
+
+    # [*, H, Q, K]
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = softmax(a, -1)
+
+    # [*, H, Q, C_hidden]
+    a = torch.matmul(a, value)
+
+    # [*, Q, H, C_hidden]
+    a = a.transpose(-2, -3)
+
+    return a
+@torch.jit.ignore
+def _attention_chunked_trainable(
+    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
+):
+    if (checkpoint and len(biases) > 2):
+        raise ValueError(
+            "Checkpointed version permits only two bias terms"
+        )
+
+    def _checkpointable_attention(q, k, v, b1, b2):
+        bs = [b for b in [b1, b2] if b is not None]
+        return _attention(q, k, v, bs)
+
+    o_chunks = []
+    checkpoint_fn = get_checkpoint_fn()
+    count = query.shape[chunk_dim]
+    for start in range(0, count, chunk_size):
+        end = start + chunk_size
+        idx = [slice(None)] * len(query.shape)
+        idx[chunk_dim] = slice(start, end)
+        idx_tup = tuple(idx)
+        q_chunk = query[idx_tup]
+        k_chunk = key[idx_tup]
+        v_chunk = value[idx_tup]
+
+        def _slice_bias(b):
+            idx[chunk_dim] = (
+                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+            )
+            return b[tuple(idx)]
+
+        if (checkpoint):
+            bias_1_chunk, bias_2_chunk = [
+                _slice_bias(b) if b is not None else None
+                for b in (biases + [None, None])[:2]
+            ]
+
+            o_chunk = checkpoint_fn(
+                _checkpointable_attention,
+                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk,
+            )
+        else:
+            bias_chunks = [_slice_bias(b) for b in biases]
+            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+        o_chunks.append(o_chunk)
+
+    o = torch.cat(o_chunks, dim=chunk_dim)
+    return o
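Chunking here is exact, not approximate: each slice along chunk_dim is an independent attention problem (the softmax is taken over the last dimension only), so concatenating per-chunk outputs reproduces the monolithic result, and checkpointing each chunk merely trades recompute for memory. A quick equivalence check with plain tensors:

    import torch

    def attn(q, k, v, bias):
        # q, k, v: [B, N, H, C]; bias broadcastable to [B, H, N, N]
        a = torch.einsum("bqhc,bkhc->bhqk", q, k) + bias
        a = a.softmax(dim=-1)
        return torch.einsum("bhqk,bkhc->bqhc", a, v)

    B, N, H, C = 6, 5, 2, 4
    q, k, v = (torch.randn(B, N, H, C) for _ in range(3))
    bias = torch.randn(B, H, N, N)

    full = attn(q, k, v, bias)
    chunks = [
        attn(q[s:s+2], k[s:s+2], v[s:s+2], bias[s:s+2])
        for s in range(0, B, 2)  # chunk along the batch-like dimension
    ]
    assert torch.allclose(full, torch.cat(chunks, dim=0), atol=1e-6)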
class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
...
@@ -225,66 +357,34 @@ class Attention(nn.Module):
        )

        self.sigmoid = nn.Sigmoid()
-       self.softmax = nn.Softmax(dim=-1)

-   def forward(
+   def _prep_qkv(
        self,
        q_x: torch.Tensor,
-       k_x: torch.Tensor,
-       v_x: torch.Tensor,
-       biases: Optional[List[torch.Tensor]] = None,
-   ) -> torch.Tensor:
-       """
-       Args:
-           q_x:
-               [*, Q, C_q] query data
-           k_x:
-               [*, K, C_k] key data
-           v_x:
-               [*, V, C_v] value data
-       Returns
-           [*, Q, C_q] attention update
-       """
+       kv_x: torch.Tensor,
+   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
-       k = self.linear_k(k_x)
-       v = self.linear_v(v_x)
+       k = self.linear_k(kv_x)
+       v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

-       # [*, H, Q, C_hidden]
-       q = permute_final_dims(q, (1, 0, 2))
-
-       # [*, H, C_hidden, K]
-       k = permute_final_dims(k, (1, 2, 0))
-
-       # [*, H, Q, K]
-       a = torch.matmul(q, k)
-       del q, k
-
-       norm = 1 / math.sqrt(self.c_hidden)  # [1]
-       a *= norm
-
-       if biases is not None:
-           for b in biases:
-               a += b
-
-       a = self.softmax(a)
-
-       # [*, H, V, C_hidden]
-       v = permute_final_dims(v, (1, 0, 2))
-
-       # [*, H, Q, C_hidden]
-       o = torch.matmul(a, v)
-
-       # [*, Q, H, C_hidden]
-       o = o.transpose(-2, -3)
+       q /= math.sqrt(self.c_hidden)
+
+       return q, k, v

+   def _wrap_up(self,
+       o: torch.Tensor,
+       q_x: torch.Tensor,
+   ) -> torch.Tensor:
        if (self.linear_g is not None):
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g
...
@@ -297,6 +397,56 @@ class Attention(nn.Module):
        return o

+   def forward(
+       self,
+       q_x: torch.Tensor,
+       kv_x: torch.Tensor,
+       biases: Optional[List[torch.Tensor]] = None,
+       use_lma: bool = False,
+       q_chunk_size: Optional[int] = None,
+       kv_chunk_size: Optional[int] = None,
+   ) -> torch.Tensor:
+       """
+       Args:
+           q_x:
+               [*, Q, C_q] query data
+           kv_x:
+               [*, K, C_k] key data
+           biases:
+               List of biases that broadcast to [*, H, Q, K]
+           use_lma:
+               Whether to use low-memory attention
+           q_chunk_size:
+               Query chunk size (for LMA)
+           kv_chunk_size:
+               Key/Value chunk size (for LMA)
+       Returns
+           [*, Q, C_q] attention update
+       """
+       if (biases is None):
+           biases = []
+       if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
+           raise ValueError(
+               "If use_lma is specified, q_chunk_size and kv_chunk_size must "
+               "be provided"
+           )
+
+       q, k, v = self._prep_qkv(q_x, kv_x)
+
+       if (use_lma):
+           biases = [
+               b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
+               for b in biases
+           ]
+
+           o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+       else:
+           o = _attention(q, k, v, biases)
+
+       o = self._wrap_up(o, q_x)
+
+       return o
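The forward rewrite collapses the old k_x/v_x pair into a single kv_x (self-attention passed the same tensor to both anyway) and exposes low-memory attention as a per-call option. A hedged usage sketch against the signature above, assuming an OpenFold checkout on the path and arbitrary channel sizes:

    import torch
    # Assumes an OpenFold checkout on the path; this is the class above.
    from openfold.model.primitives import Attention

    # Positional args: c_q, c_k, c_v, c_hidden, no_heads
    mha = Attention(32, 32, 32, 16, 4)
    x = torch.randn(1, 10, 32)

    o = mha(q_x=x, kv_x=x)                       # standard path
    o_lma = mha(q_x=x, kv_x=x, use_lma=True,     # Rabe-Staats path
                q_chunk_size=4, kv_chunk_size=4)
    assert o.shape == (1, 10, 32)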
class GlobalAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
...
@@ -322,7 +472,6 @@ class GlobalAttention(nn.Module):
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()
-       self.softmax = nn.Softmax(dim=-1)

    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # [*, N_res, C_in]
...
@@ -348,7 +497,7 @@ class GlobalAttention(nn.Module):
        )
        bias = (self.inf * (mask - 1))[..., :, None, :]
        a += bias
-       a = self.softmax(a)
+       a = softmax(a)

        # [*, N_res, H, C_hidden]
        o = torch.matmul(
...
@@ -374,14 +523,13 @@
        return m
-@torch.jit.script
 def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
-   kv_chunk_size: int
+   kv_chunk_size: int,
 ):
    no_q, no_kv = q.shape[-3], k.shape[-3]
...
@@ -389,34 +537,34 @@ def _lma(
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s : q_s + q_chunk_size, :, :]
-       big_bias_chunks = [
+       large_bias_chunks = [
            b[..., q_s : q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s : kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s : kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [
-               b[..., kv_s : kv_s + kv_chunk_size] for b in big_bias_chunks
+               b[..., kv_s : kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
-               "...qhd,...khd->...hqk", q_chunk, k_chunk
+               "...qhd,...khd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            a = a.transpose(-2, -3)

-           max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
+           max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

-           maxes.append(max_a.squeeze(-1))
+           maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)
...
@@ -437,111 +585,3 @@ def _lma(
        o[..., q_s : q_s + q_chunk_size, :, :] = q_chunk_out

    return o
-
-
-class LowMemoryAttention(nn.Module):
-    """
-    Standard multi-head attention using AlphaFold's default layer
-    initialization. Allows multiple bias vectors. Implements Rabe and Staats'
-    low-memory self-attention algorithm.
-    """
-    def __init__(
-        self,
-        c_q: int,
-        c_k: int,
-        c_v: int,
-        c_hidden: int,
-        no_heads: int,
-        gating: bool = True,
-    ):
-        """
-        Args:
-            c_q:
-                Input dimension of query data
-            c_k:
-                Input dimension of key data
-            c_v:
-                Input dimension of value data
-            c_hidden:
-                Per-head hidden dimension
-            no_heads:
-                Number of attention heads
-            gating:
-                Whether the output should be gated using query data
-            chunk_size:
-                Trades memory for better parallelization. A low value
-                corresponds to lower memory usage.
-        """
-        super().__init__()
-
-        self.c_q = c_q
-        self.c_k = c_k
-        self.c_v = c_v
-        self.c_hidden = c_hidden
-        self.no_heads = no_heads
-        self.gating = gating
-
-        self.linear_q = Linear(
-            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_k = Linear(
-            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_v = Linear(
-            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_o = Linear(
-            self.c_hidden * self.no_heads, self.c_q, init="final"
-        )
-
-        if self.gating:
-            self.linear_g = Linear(
-                self.c_q, self.c_hidden * self.no_heads, init="gating"
-            )
-
-        self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(
-        self,
-        q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
-        q_chunk_size: int,
-        kv_chunk_size: int,
-        biases: Optional[List[torch.Tensor]] = None,
-    ):
-        if (biases is None):
-            biases = []
-        else:
-            biases = [
-                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
-                for b in biases
-            ]
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(q.shape[-1])
-
-        o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
-
-        if self.gating:
-            g = self.sigmoid(self.linear_g(q_x))
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
-
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
-
-        return o
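_lma never materializes the full [Q, K] logit matrix: every key/value chunk contributes a local max, an exp-sum weight, and an exp-weighted value, and the partial results are renormalized against the global max afterwards (Rabe and Staats' trick). The combine step in isolation, checked against an ordinary softmax:

    import torch

    def streaming_softmax_combine(maxes, weights, values):
        # maxes[i]: [Q, H] chunk-local logit maxima
        # weights[i]: [Q, H] sums of exp(logits - max) per chunk
        # values[i]: [Q, H, C] exp-weighted value sums per chunk
        global_max = torch.stack(maxes, dim=-1).max(dim=-1)[0]  # [Q, H]
        out = torch.zeros_like(values[0])
        denom = torch.zeros_like(weights[0])
        for m, w, v in zip(maxes, weights, values):
            corr = torch.exp(m - global_max)  # rescale to the global max
            out = out + v * corr[..., None]
            denom = denom + w * corr
        return out / denom[..., None]

    # Sanity check against a plain softmax over the concatenated key axis.
    Q, H, C, K = 3, 2, 4, 8
    logits = torch.randn(Q, H, K)
    vals = torch.randn(K, H, C)
    ref = torch.einsum("qhk,khc->qhc", logits.softmax(dim=-1), vals)

    maxes, weights, values = [], [], []
    for s in range(0, K, 4):
        l = logits[..., s:s+4]
        m = l.max(dim=-1)[0]
        e = torch.exp(l - m[..., None])
        maxes.append(m)
        weights.append(e.sum(dim=-1))
        values.append(torch.einsum("qhk,khc->qhc", e, vals[s:s+4]))

    out = streaming_softmax_combine(maxes, weights, values)
    assert torch.allclose(ref, out, atol=1e-6)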
openfold/model/structure_module.py
...
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 from typing import Optional, Tuple

-from openfold.model.primitives import Linear, ipa_point_weights_init_
+from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
...
@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
        )
-       a = a * math.sqrt(1.0 / (3 * self.c_hidden))
-       a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+       a *= math.sqrt(1.0 / (3 * self.c_hidden))
+       a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...
@@ -323,7 +323,7 @@ class InvariantPointAttention(nn.Module):
        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))
        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)
...
@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
-       o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3)
+       o = torch.matmul(
+           a, v.transpose(-2, -3).to(dtype=a.dtype)
+       ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)
...
@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        # [*, N_res, H, C_z]
-       o_pair = torch.matmul(a.transpose(-2, -3), z)
+       o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)
...
@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-           )
+           ).to(dtype=z.dtype)
        )

        return s
...
@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
            self.layers.append(l)

        self.dropout = nn.Dropout(self.dropout_rate)
-       self.layer_norm = nn.LayerNorm(self.c)
+       self.layer_norm = LayerNorm(self.c)

    def forward(self, s):
        for l in self.layers:
...
@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
        self.atom_mask = None
        self.lit_positions = None

-       self.layer_norm_s = nn.LayerNorm(self.c_s)
-       self.layer_norm_z = nn.LayerNorm(self.c_z)
+       self.layer_norm_s = LayerNorm(self.c_s)
+       self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)
...
@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
-       self.layer_norm_ipa = nn.LayerNorm(self.c_s)
+       self.layer_norm_ipa = LayerNorm(self.c_s)

        self.transition = StructureModuleTransition(
            self.c_s,
...
openfold/model/template.py
...
@@ -19,7 +19,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.model.dropout import (
     DropoutRowwise,
     DropoutColumnwise,
...
@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": z,
-           "k_x": t,
-           "v_x": t,
+           "kv_x": t,
            "biases": biases,
        }
        return chunk_layer(
...
@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
        if chunk_size is not None:
            z = self._chunk(z, t, biases, chunk_size)
        else:
-           z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
+           z = self.mha(q_x=z, kv_x=t, biases=biases)

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)
...
@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
            )
            self.blocks.append(block)

-       self.layer_norm = nn.LayerNorm(c_t)
+       self.layer_norm = LayerNorm(c_t)

    def forward(
        self,
...
openfold/model/triangular_attention.py
...
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import partialmethod
+from functools import partialmethod, partial
 import math
 from typing import Optional, List

 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
...
@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
        self.starting = starting
        self.inf = inf

-       self.layer_norm = nn.LayerNorm(self.c_in)
+       self.layer_norm = LayerNorm(self.c_in)

        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
...
@@ -65,12 +65,11 @@ class TriangleAttention(nn.Module):
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": x,
-           "k_x": x,
-           "v_x": x,
+           "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
-           self.mha,
+           partial(self.mha),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
...
@@ -116,7 +115,7 @@ class TriangleAttention(nn.Module):
        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
-           x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
+           x = self.mha(q_x=x, kv_x=x, biases=biases)

        if not self.starting:
            x = x.transpose(-2, -3)
...
openfold/model/triangular_multiplicative_update.py
...
@@ -19,7 +19,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import permute_final_dims
...
@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

-       self.layer_norm_in = nn.LayerNorm(self.c_z)
-       self.layer_norm_out = nn.LayerNorm(self.c_hidden)
+       self.layer_norm_in = LayerNorm(self.c_z)
+       self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()
...
openfold/utils/checkpointing.py
...
@@ -15,17 +15,27 @@
 import deepspeed
 import torch
 import torch.utils.checkpoint
-from typing import Any, Tuple, List, Callable
+from typing import Any, Tuple, List, Callable, Optional

 BLOCK_ARG = Any
 BLOCK_ARGS = List[BLOCK_ARG]

+def get_checkpoint_fn():
+    if (deepspeed.checkpointing.is_configured()):
+        checkpoint = deepspeed.checkpointing.checkpoint
+    else:
+        checkpoint = torch.utils.checkpoint.checkpoint
+
+    return checkpoint
+
 @torch.jit.ignore
 def checkpoint_blocks(
     blocks: List[Callable],
     args: BLOCK_ARGS,
-    blocks_per_ckpt: int,
+    blocks_per_ckpt: Optional[int],
 ) -> BLOCK_ARGS:
     """
     Chunk a list of blocks and run each chunk with activation
...
@@ -68,10 +78,7 @@ def checkpoint_blocks(
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

-   if (deepspeed.checkpointing.is_configured()):
-       checkpoint = deepspeed.checkpointing.checkpoint
-   else:
-       checkpoint = torch.utils.checkpoint.checkpoint
+   checkpoint = get_checkpoint_fn()

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
...
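Hoisting the backend choice into get_checkpoint_fn lets modules like MSAAttention pick DeepSpeed's checkpoint when one is configured and fall back to torch.utils.checkpoint otherwise, and blocks_per_ckpt is now Optional so None cleanly disables checkpointing. A usage sketch, assuming (from the docstring and signature above) that each block maps the running args to the inputs of the next block:

    import torch
    import torch.nn as nn
    # Assumes an OpenFold checkout on the path.
    from openfold.utils.checkpointing import checkpoint_blocks

    layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))
    # Each block consumes and produces the running activation.
    blocks = [lambda x, l=l: torch.relu(l(x)) for l in layers]

    x = torch.randn(2, 16, requires_grad=True)
    out = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=2)          # recompute in pairs
    out_plain = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=None) # no checkpointing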
openfold/utils/import_weights.py
...
@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
-           "msa_transition": MSATransitionParams(b.msa_transition),
-           "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
-           "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
-           "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
-           "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
-           "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
-           "pair_transition": PairTransitionParams(b.pair_transition),
+           "msa_transition": MSATransitionParams(b.core.msa_transition),
+           "outer_product_mean":
+               OuterProductMeanParams(b.core.outer_product_mean),
+           "triangle_multiplication_outgoing":
+               TriMulOutParams(b.core.tri_mul_out),
+           "triangle_multiplication_incoming":
+               TriMulInParams(b.core.tri_mul_in),
+           "triangle_attention_starting_node":
+               TriAttParams(b.core.tri_att_start),
+           "triangle_attention_ending_node":
+               TriAttParams(b.core.tri_att_end),
+           "pair_transition":
+               PairTransitionParams(b.core.pair_transition),
        }
        return d
...
@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        [TemplatePairBlockParams(b) for b in tps_blocks]
    )

-   ems_blocks = model.extra_msa_stack.stack.blocks
+   ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
...
openfold/utils/loss.py
...
@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
 def sigmoid_cross_entropy(logits, labels):
-    log_p = torch.nn.functional.logsigmoid(logits)
-    log_not_p = torch.nn.functional.logsigmoid(-logits)
+    log_p = torch.log(torch.sigmoid(logits))
+    log_not_p = torch.log(torch.sigmoid(-logits))
     loss = -labels * log_p - (1 - labels) * log_not_p
     return loss
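The two formulations are mathematically identical, but not numerically: torch.log(torch.sigmoid(x)) underflows to -inf for very negative logits where F.logsigmoid stays finite, so the swap is only safe while logits remain in a moderate range:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([-5.0, 0.0, 5.0])
    assert torch.allclose(torch.log(torch.sigmoid(x)), F.logsigmoid(x), atol=1e-6)

    x_extreme = torch.tensor([-200.0])
    print(torch.log(torch.sigmoid(x_extreme)))  # tensor([-inf]): underflow
    print(F.logsigmoid(x_extreme))              # tensor([-200.]): stable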
...
@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    return pred_lddt_ca * 100

-def lddt_loss(
-   logits: torch.Tensor,
+def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
-   resolution: torch.Tensor,
    cutoff: float = 15.0,
-   no_bins: int = 50,
-   min_resolution: float = 0.1,
-   max_resolution: float = 3.0,
    eps: float = 1e-10,
-   **kwargs,
+   per_residue: bool = True,
 ) -> torch.Tensor:
    n = all_atom_mask.shape[-2]
-   ca_pos = residue_constants.atom_order["CA"]
-   all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
-   all_atom_positions = all_atom_positions[..., ca_pos, :]
-   all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    dmat_true = torch.sqrt(
        eps
        + torch.sum(
...
@@ -389,8 +378,63 @@ def lddt_loss(
    )
    score = score * 0.25

-   norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
-   score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
+   dims = (-1,) if per_residue else (-2, -1)
+   norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
+   score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
+
+   return score
+
+
+def lddt_ca(
+   all_atom_pred_pos: torch.Tensor,
+   all_atom_positions: torch.Tensor,
+   all_atom_mask: torch.Tensor,
+   cutoff: float = 15.0,
+   eps: float = 1e-10,
+   per_residue: bool = True,
+) -> torch.Tensor:
+   ca_pos = residue_constants.atom_order["CA"]
+   all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+   all_atom_positions = all_atom_positions[..., ca_pos, :]
+   all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+   return lddt(
+       all_atom_pred_pos,
+       all_atom_positions,
+       all_atom_mask,
+       cutoff=cutoff,
+       eps=eps,
+       per_residue=per_residue,
+   )
+
+
+def lddt_loss(
+   logits: torch.Tensor,
+   all_atom_pred_pos: torch.Tensor,
+   all_atom_positions: torch.Tensor,
+   all_atom_mask: torch.Tensor,
+   resolution: torch.Tensor,
+   cutoff: float = 15.0,
+   no_bins: int = 50,
+   min_resolution: float = 0.1,
+   max_resolution: float = 3.0,
+   eps: float = 1e-10,
+   **kwargs,
+) -> torch.Tensor:
+   n = all_atom_mask.shape[-2]
+
+   ca_pos = residue_constants.atom_order["CA"]
+   all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+   all_atom_positions = all_atom_positions[..., ca_pos, :]
+   all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+   score = lddt(
+       all_atom_pred_pos, all_atom_positions, all_atom_mask,
+       cutoff=cutoff, eps=eps,
+   )

    score = score.detach()
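The refactor separates the metric from the loss: lddt computes (optionally per-residue) lDDT over arbitrary atom sets, lddt_ca restricts it to Cα atoms, and lddt_loss keeps the training-only binning and resolution gating. A validation-style use of the new helper, assuming an OpenFold checkout on the path:

    import torch
    # Assumes an OpenFold checkout on the path.
    from openfold.utils.loss import lddt_ca

    n_res, n_atoms = 64, 37
    pred = torch.randn(n_res, n_atoms, 3)
    gt = pred + 0.1 * torch.randn_like(pred)  # near-perfect "prediction"
    mask = torch.ones(n_res, n_atoms)

    per_res = lddt_ca(pred, gt, mask, per_residue=True)    # [n_res] scores
    global_score = lddt_ca(pred, gt, mask, per_residue=False)  # one scalar
    print(per_res.shape, float(global_score))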
...
@@ -1462,7 +1506,7 @@ class AlphaFoldLoss(nn.Module):
        self.config = config

    def forward(self, out, batch):
-       if "violation" not in out.keys() and self.config.violation.weight:
+       if "violation" not in out.keys():
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
...
@@ -1509,22 +1553,26 @@ class AlphaFoldLoss(nn.Module):
                out["violation"],
                **batch,
            ),
-           "tm": lambda: tm_loss(
-               logits=out["tm_logits"],
-               **{**batch, **out, **self.config.tm},
-           ),
        }

+       if (self.config.tm.enabled):
+           loss_fns["tm"] = lambda: tm_loss(
+               logits=out["tm_logits"],
+               **{**batch, **out, **self.config.tm},
+           )
+
        cum_loss = 0.
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
-           loss = loss_fn()
-           if (torch.isnan(loss) or torch.isinf(loss)):
-               logging.warning(f"{loss_name} loss is NaN. Skipping...")
-               loss = loss.new_tensor(0., requires_grad=True)
-           cum_loss = cum_loss + weight * loss
+           if weight:
+               loss = loss_fn()
+               if (torch.isnan(loss) or torch.isinf(loss)):
+                   logging.warning(f"{loss_name} loss is NaN. Skipping...")
+                   loss = loss.new_tensor(0., requires_grad=True)
+               cum_loss = cum_loss + weight * loss
+
+       seq_len = torch.mean(batch["seq_length"].float())
+       crop_len = batch["aatype"].shape[-1]
+       cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

        # Scale the loss by the square root of the minimum of the crop size and
        # the (average) sequence length. See subsection 1.9.
...
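Two behavioral changes ride along here: zero-weight losses are no longer evaluated at all (so tm_loss never runs unless tm.enabled is set), and the sqrt(min(seq_len, crop_len)) scaling from supplement subsection 1.9 is applied to the accumulated loss. The scaling factor on its own, with illustrative numbers:

    import torch

    # Losses on short crops should not be over-weighted relative to
    # full-length sequences, hence the square-root scaling.
    seq_len = torch.tensor(350.0)  # mean true sequence length in the batch
    crop_len = 256                 # batch["aatype"].shape[-1]
    scale = torch.sqrt(torch.minimum(seq_len, torch.tensor(float(crop_len))))
    print(scale)  # sqrt(256) = 16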
openfold/utils/rigid_utils.py
...
@@ -26,7 +26,7 @@ def rot_matmul(
 ) -> torch.Tensor:
    """
    Performs matrix multiplication of two rotation matrix tensors. Written
-   out by hand to avoid transfer to low-precision tensor cores.
+   out by hand to avoid AMP downcasting.
    Args:
        a: [*, 3, 3] left multiplicand
...
@@ -86,7 +86,7 @@ def rot_vec_mul(
 ) -> torch.Tensor:
    """
-   Applies a rotation to a vector. Written out by hand to avoid transfer
-   to low-precision tensor cores.
+   Applies a rotation to a vector. Written out by hand to avoid AMP
+   downcasting.
    Args:
        r: [*, 3, 3] rotation matrices
...
@@ -323,6 +323,12 @@ class Rotation:
                "Incorrectly shaped rotation matrix or quaternion"
            )

+       # Force full-precision
+       if (quats is not None):
+           quats = quats.to(dtype=torch.float32)
+       if (rot_mats is not None):
+           rot_mats = rot_mats.to(dtype=torch.float32)
+
        if (quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
...
@@ -857,6 +863,9 @@ class Rigid:
            (rots.device != trans.device)):
            raise ValueError("Rots and trans incompatible")

+       # Force full precision. Happens to the rotations automatically.
+       trans = trans.to(dtype=torch.float32)
+
        self._rots = rots
        self._trans = trans
...
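The pattern in this file is the mirror image of the bf16 casts elsewhere: geometry (quaternions, rotation matrices, translations) is pinned to float32 at construction, and consumers such as the IPA module cast back to the activation dtype at the point of use. A minimal illustration of the construction-time pinning:

    import torch

    class Frame:
        # Illustrative stand-in for Rigid: store geometry in fp32
        # regardless of the activation dtype handed in.
        def __init__(self, trans: torch.Tensor):
            self.trans = trans.to(dtype=torch.float32)

    t_bf16 = torch.randn(8, 3, dtype=torch.bfloat16)
    frame = Frame(t_bf16)
    assert frame.trans.dtype == torch.float32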
scripts/generate_mmcif_cache.py
deleted 100644 → 0
-import argparse
-from functools import partial
-import logging
-from multiprocessing import Pool
-import os
-import sys
-import json
-
-sys.path.append(".")  # an innocent hack to get this to run from the top level
-
-from tqdm import tqdm
-
-from openfold.data.mmcif_parsing import parse
-
-
-def parse_file(f, args):
-    with open(os.path.join(args.mmcif_dir, f), "r") as fp:
-        mmcif_string = fp.read()
-
-    file_id = os.path.splitext(f)[0]
-    mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
-    if mmcif.mmcif_object is None:
-        logging.info(f"Could not parse {f}. Skipping...")
-        return {}
-    else:
-        mmcif = mmcif.mmcif_object
-
-    local_data = {}
-    local_data["release_date"] = mmcif.header["release_date"]
-    local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
-
-    return {file_id: local_data}
-
-
-def main(args):
-    files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
-    fn = partial(parse_file, args=args)
-    data = {}
-    with Pool(processes=args.no_workers) as p:
-        with tqdm(total=len(files)) as pbar:
-            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
-                data.update(d)
-                pbar.update()
-
-    with open(args.output_path, "w") as fp:
-        fp.write(json.dumps(data, indent=4))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "mmcif_dir", type=str, help="Directory containing mmCIF files"
-    )
-    parser.add_argument(
-        "output_path", type=str, help="Path for .json output"
-    )
-    parser.add_argument(
-        "--no_workers", type=int, default=4,
-        help="Number of workers to use for parsing"
-    )
-    parser.add_argument(
-        "--chunksize", type=int, default=10,
-        help="How many files should be distributed to each worker at a time"
-    )
-
-    args = parser.parse_args()
-
-    main(args)
scripts/precompute_alignments.py
 import argparse
+from functools import partial
+import json
 import logging
 import os
+import threading
+from multiprocessing import cpu_count
+from shutil import copyfile
 import tempfile

 import openfold.data.mmcif_parsing as mmcif_parsing
...
@@ -10,30 +15,58 @@ from openfold.np import protein, residue_constants
 from utils import add_data_args

-#python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ data/uniref90/uniref90.fasta data/mgnify/mgy_clusters_2018_12.fa data/pdb70/pdb70 data/pdb_mmcif/mmcif_files/ data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt --cpus 16 --jackhmmer_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/jackhmmer --hhblits_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhblits --hhsearch_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhsearch --kalign_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/kalign
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.WARNING)
-def main(args):
-    # Build the alignment tool runner
-    alignment_runner = AlignmentRunner(
-        jackhmmer_binary_path=args.jackhmmer_binary_path,
-        hhblits_binary_path=args.hhblits_binary_path,
-        hhsearch_binary_path=args.hhsearch_binary_path,
-        uniref90_database_path=args.uniref90_database_path,
-        mgnify_database_path=args.mgnify_database_path,
-        bfd_database_path=args.bfd_database_path,
-        uniclust30_database_path=args.uniclust30_database_path,
-        pdb70_database_path=args.pdb70_database_path,
-        use_small_bfd=args.bfd_database_path is None,
-        no_cpus=args.cpus,
-    )
-
-    for f in os.listdir(args.input_dir):
+def run_seq_group_alignments(seq_groups, alignment_runner, args):
+    dirs = set(os.listdir(args.output_dir))
+    for seq, names in seq_groups:
+        first_name = names[0]
+        alignment_dir = os.path.join(args.output_dir, first_name)
+
+        os.makedirs(alignment_dir, exist_ok=True)
+        # try:
+        #     os.makedirs(alignment_dir)
+        # except Exception as e:
+        #     logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
+        #     continue
+
+        fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
+        with os.fdopen(fd, 'w') as fp:
+            fp.write(f'>query\n{seq}')
+
+        try:
+            alignment_runner.run(fasta_path, alignment_dir)
+        except:
+            logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
+            os.remove(fasta_path)
+            os.rmdir(alignment_dir)
+            continue
+
+        os.remove(fasta_path)
+
+        for name in names[1:]:
+            #if(name in dirs):
+            #    logging.warning(
+            #        f'{name} has already been processed. Skipping...'
+            #    )
+            #    continue
+            cp_dir = os.path.join(args.output_dir, name)
+            os.makedirs(cp_dir, exist_ok=True)
+            for f in os.listdir(alignment_dir):
+                copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
+
+
+def parse_and_align(files, alignment_runner, args):
+    for f in files:
        path = os.path.join(args.input_dir, f)
        file_id = os.path.splitext(f)[0]
-       seqs = {}
+       seq_group_dict = {}
        if (f.endswith('.cif')):
            with open(path, 'r') as fp:
                mmcif_str = fp.read()
...
@@ -47,9 +80,10 @@ def main(args):
            else:
                continue
            mmcif = mmcif.mmcif_object
-           for k, v in mmcif.chain_to_seqres.items():
-               chain_id = '_'.join([file_id, k])
-               seqs[chain_id] = v
+           for chain_letter, seq in mmcif.chain_to_seqres.items():
+               chain_id = '_'.join([file_id, chain_letter])
+               l = seq_group_dict.setdefault(seq, [])
+               l.append(chain_id)
        elif (f.endswith('.fasta') or f.endswith('.fa')):
            with open(path, 'r') as fp:
                fasta_str = fp.read()
...
@@ -61,7 +95,7 @@ def main(args):
            else:
                logging.warning(msg)
            input_sequence = input_seqs[0]
-           seqs[file_id] = input_sequence
+           seq_group_dict[input_sequence] = [file_id]
        elif (f.endswith('.core')):
            with open(path, 'r') as fp:
                core_str = fp.read()
...
@@ -71,27 +105,114 @@ def main(args):
                residue_constants.restypes_with_x[aatype[i]]
                for i in range(len(aatype))
            ])
-           seqs[file_id] = seq
+           seq_group_dict[seq] = [file_id]
        else:
            continue

-       for name, seq in seqs.items():
-           alignment_dir = os.path.join(args.output_dir, name)
-           if (os.path.isdir(alignment_dir)):
-               logging.info(f'{f} has already been processed. Skipping...')
-               continue
-           os.makedirs(alignment_dir)
-           fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
-           with os.fdopen(fd, 'w') as fp:
-               fp.write(f'>query\n{seq}')
-           alignment_runner.run(
-               fasta_path, alignment_dir
-           )
-           os.remove(fasta_path)
+       seq_group_tuples = [(k, v) for k, v in seq_group_dict.items()]
+       run_seq_group_alignments(seq_group_tuples, alignment_runner, args)
+
+
+def main(args):
+    # Build the alignment tool runner
+    alignment_runner = AlignmentRunner(
+        jackhmmer_binary_path=args.jackhmmer_binary_path,
+        hhblits_binary_path=args.hhblits_binary_path,
+        hhsearch_binary_path=args.hhsearch_binary_path,
+        uniref90_database_path=args.uniref90_database_path,
+        mgnify_database_path=args.mgnify_database_path,
+        bfd_database_path=args.bfd_database_path,
+        uniclust30_database_path=args.uniclust30_database_path,
+        pdb70_database_path=args.pdb70_database_path,
+        use_small_bfd=args.bfd_database_path is None,
+        no_cpus=args.cpus_per_task,
+    )
+
+    files = list(os.listdir(args.input_dir))
+
+    # Do some filtering
+    if (args.mmcif_cache is not None):
+        with open(args.mmcif_cache, "r") as fp:
+            cache = json.load(fp)
+    else:
+        cache = None
+
+    dirs = []
+    if (cache is not None and args.filter):
+        dirs = set(os.listdir(args.output_dir))
+
+        def prot_is_done(f):
+            prot_id = os.path.splitext(f)[0]
+            if (prot_id in cache):
+                chain_ids = cache[prot_id]["chain_ids"]
+                for c in chain_ids:
+                    full_name = prot_id + "_" + c
+                    if (not full_name in dirs):
+                        return False
+            else:
+                return False
+            return True
+
+        files = [f for f in files if not prot_is_done(f)]
+
+    def split_up_arglist(arglist):
+        # Split up the survivors
+        if (os.environ.get("SLURM_JOB_NUM_NODES", 0)):
+            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
+            if (num_nodes > 1):
+                node_id = int(os.environ["SLURM_NODEID"])
+                logging.warning(f"Num nodes: {num_nodes}")
+                logging.warning(f"Node ID: {node_id}")
+                arglist = arglist[node_id::num_nodes]
+
+        t_arglist = []
+        for i in range(args.no_tasks):
+            t_arglist.append(arglist[i::args.no_tasks])
+
+        return t_arglist
+
+    if (cache is not None and "seqs" in next(iter(cache.values()))):
+        seq_group_dict = {}
+        for f in files:
+            prot_id = os.path.splitext(f)[0]
+            if (prot_id in cache):
+                prot_cache = cache[prot_id]
+                chains_seqs = zip(
+                    prot_cache["chain_ids"], prot_cache["seqs"]
+                )
+                for chain, seq in chains_seqs:
+                    chain_name = prot_id + "_" + chain
+                    if (chain_name not in dirs):
+                        l = seq_group_dict.setdefault(seq, [])
+                        l.append(chain_name)
+
+        func = partial(
+            run_seq_group_alignments,
+            alignment_runner=alignment_runner,
+            args=args,
+        )
+        seq_groups = [(k, v) for k, v in seq_group_dict.items()]
+        # Sort them by group length so the tasks are approximately balanced
+        seq_groups = sorted(seq_groups, key=lambda x: len(x[1]))
+        task_arglist = [[a] for a in split_up_arglist(seq_groups)]
+    else:
+        func = partial(
+            parse_and_align,
+            alignment_runner=alignment_runner,
+            args=args,
+        )
+        task_arglist = [[a] for a in split_up_arglist(files)]
+
+    threads = []
+    for i, task_args in enumerate(task_arglist):
+        print(f"Started thread {i}...")
+        t = threading.Thread(target=func, args=task_args)
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
 if __name__ == "__main__":
...
@@ -111,9 +232,19 @@ if __name__ == "__main__":
        help="Whether to crash on parsing errors"
    )
    parser.add_argument(
-       "--cpus", type=int, default=4,
+       "--cpus_per_task", type=int, default=cpu_count(),
        help="Number of CPUs to use"
    )
+   parser.add_argument(
+       "--mmcif_cache", type=str, default=None,
+       help="Path to mmCIF cache. Used to filter files to be parsed"
+   )
+   parser.add_argument(
+       "--no_tasks", type=int, default=1,
+   )
+   parser.add_argument(
+       "--filter", type=bool, default=True,
+   )

    args = parser.parse_args()
...
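Two ideas carry the rewritten script: chains with identical sequences are grouped so each unique sequence is aligned once and the result is copied to its sibling chains, and the surviving work is divided across SLURM nodes and local threads by stride-slicing the task list. The striding on its own, with toy data:

    # Illustrative only: slicing [node_id::num_nodes] gives each node a
    # disjoint, near-equal share of the work without any coordination.
    tasks = [f"chain_{i}" for i in range(10)]
    num_nodes, no_tasks = 2, 2

    for node_id in range(num_nodes):
        node_share = tasks[node_id::num_nodes]
        for i in range(no_tasks):
            thread_share = node_share[i::no_tasks]
            print(node_id, i, thread_share)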