Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d71d37ff
"examples/llm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "8891aa0cc8d1bad420890a60c5d63d253834ee60"
Commit
d71d37ff
authored
Oct 27, 2021
by
Gustaf Ahdritz
Browse files
Fix empty template feature bug
parent
81ae777d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
58 deletions
+93
-58
openfold/config.py
openfold/config.py
+0
-3
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+20
-4
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+27
-19
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+19
-9
openfold/model/model.py
openfold/model/model.py
+27
-23
No files found.
openfold/config.py
View file @
d71d37ff
...
...
@@ -189,7 +189,6 @@ config = mlc.ConfigDict(
"max_msa_clusters"
:
128
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"num_ensemble"
:
1
,
"crop"
:
False
,
"crop_size"
:
None
,
"supervised"
:
False
,
...
...
@@ -202,7 +201,6 @@ config = mlc.ConfigDict(
"max_msa_clusters"
:
128
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"num_ensemble"
:
1
,
"crop"
:
False
,
"crop_size"
:
None
,
"supervised"
:
True
,
...
...
@@ -215,7 +213,6 @@ config = mlc.ConfigDict(
"max_msa_clusters"
:
128
,
"max_template_hits"
:
20
,
"max_templates"
:
4
,
"num_ensemble"
:
1
,
"crop"
:
True
,
"crop_size"
:
256
,
"supervised"
:
True
,
...
...
openfold/data/data_pipeline.py
View file @
d71d37ff
...
...
@@ -21,12 +21,21 @@ import numpy as np
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.data.tools.utils
import
to_date
from
openfold.np
import
residue_constants
,
protein
FeatureDict = Mapping[str, np.ndarray]


def empty_template_feats(n_res) -> FeatureDict:
    """Return template features that describe zero templates.

    Every array has a leading dimension of length 0, so the result can be
    merged into a feature dict in place of real template features when no
    template hits are available.

    Args:
        n_res: Length used for the residue dimension of the per-residue
            arrays.

    Returns:
        A dict with the keys ``template_aatype``,
        ``template_all_atom_positions``, ``template_sum_probs`` and
        ``template_all_atom_mask``.
    """
    # NOTE(review): 37 is presumably the number of atom types per residue
    # (cf. residue_constants) -- confirm before replacing the literal.
    # dtype is passed to np.zeros directly so no intermediate float64 array
    # is created and then cast.
    return {
        "template_aatype": np.zeros((0, n_res), dtype=np.int64),
        "template_all_atom_positions": np.zeros(
            (0, n_res, 37, 3), dtype=np.float32
        ),
        "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
        "template_all_atom_mask": np.zeros(
            (0, n_res, 37), dtype=np.float32
        ),
    }
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
...
...
@@ -340,7 +349,7 @@ class DataPipeline:
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
if
(
len
(
hits_cat
)
==
0
):
template_features
=
{}
template_features
=
empty_template_feats
(
len
(
input_sequence
))
else
:
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
...
...
@@ -389,7 +398,7 @@ class DataPipeline:
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
if
(
len
(
hits_cat
)
==
0
):
template_features
=
{}
template_features
=
empty_template_feats
(
len
(
input_sequence
))
else
:
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
...
...
@@ -399,6 +408,12 @@ class DataPipeline:
)
template_features
=
templates_result
.
features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if
(
template_features
[
"template_aatype"
].
shape
[
0
]
==
0
):
template_features
=
empty_template_feats
(
len
(
input_sequence
))
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
return
{
**
mmcif_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -415,13 +430,14 @@ class DataPipeline:
pdb_str
=
pdb_path
protein_object
=
protein
.
from_pdb_string
(
pdb_str
)
input_sequence
=
protein_object
.
aatype
pdb_feats
=
make_pdb_features
(
protein_object
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
if
(
len
(
hits_cat
)
==
0
):
template_features
=
{}
template_features
=
empty_template_feats
(
len
(
input_sequence
))
else
:
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
...
...
openfold/data/data_transforms.py
View file @
d71d37ff
...
...
@@ -85,17 +85,18 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
    """Convert template residue types to indices in this project's ordering.

    When at least one template is present, ``protein["template_aatype"]`` is
    reduced from one-hot to indices with an argmax over the last dimension
    and then remapped through ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``.
    When there are no templates (leading dimension of length 0) the entry is
    left untouched.

    Args:
        protein: Feature dict containing a ``"template_aatype"`` tensor whose
            first dimension indexes templates.

    Returns:
        The same ``protein`` dict, modified in place.
    """
    num_templates = protein["template_aatype"].shape[0]

    # The conversion runs exactly once, and only when there is something to
    # convert: empty template features are not processed.
    if num_templates > 0:
        # Map one-hot to indices
        protein["template_aatype"] = torch.argmax(
            protein["template_aatype"], dim=-1
        )
        # Map hhsearch-aatype to our aatype.
        new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
        new_order = torch.tensor(
            new_order_list, dtype=torch.int64
        ).expand(num_templates, -1)
        protein["template_aatype"] = torch.gather(
            new_order, 1, index=protein["template_aatype"]
        )

    return protein
...
...
@@ -169,10 +170,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
return
protein
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
):
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq
=
protein
[
"msa"
].
shape
[
0
]
shuffled
=
torch
.
randperm
(
num_seq
-
1
)
+
1
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
num_sel
=
min
(
max_seq
,
num_seq
)
sel_seq
,
not_sel_seq
=
torch
.
split
(
...
...
@@ -1095,18 +1099,22 @@ def random_crop_to_size(
seed
=
None
,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
seq_length
=
protein
[
"seq_length"
]
if
"template_mask"
in
protein
:
num_templates
=
protein
[
"template_mask"
].
shape
[
-
1
]
else
:
num_templates
=
protein
[
"aatype"
].
new_zeros
((
1
,))
num_templates
=
0
num_res_crop_size
=
min
(
int
(
seq_length
),
crop_size
)
# No need to subsample templates if there aren't any
subsample_templates
=
subsample_templates
and
num_templates
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
num_res_crop_size
=
min
(
int
(
seq_length
),
crop_size
)
def
_randint
(
lower
,
upper
):
return
int
(
torch
.
randint
(
...
...
openfold/data/input_pipeline.py
View file @
d71d37ff
...
...
@@ -86,8 +86,16 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
max_msa_clusters
=
pad_msa_clusters
max_extra_msa
=
common_cfg
.
max_extra_msa
msa_seed
=
None
if
(
not
common_cfg
.
resample_msa_in_recycling
):
msa_seed
=
ensemble_seed
transforms
.
append
(
data_transforms
.
sample_msa
(
max_msa_clusters
,
keep_extra
=
True
)
data_transforms
.
sample_msa
(
max_msa_clusters
,
keep_extra
=
True
,
seed
=
msa_seed
,
)
)
if
"masked_msa"
in
common_cfg
:
...
...
@@ -122,7 +130,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
mode_cfg
.
max_templates
,
crop_feats
,
mode_cfg
.
subsample_templates
,
seed
=
ensemble_seed
,
seed
=
ensemble_seed
+
1
,
)
)
transforms
.
append
(
...
...
@@ -159,21 +167,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d
[
"ensemble_index"
]
=
i
return
fn
(
d
)
tensors
=
compose
(
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
))(
tensors
)
no_templates
=
True
if
(
"template_aatype"
in
tensors
):
no_templates
=
tensors
[
"template_aatype"
].
shape
[
0
]
==
0
num_ensemble
=
mode_cfg
.
num_ensemble
nonensembled
=
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
,
)
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
if
common_cfg
.
resample_msa_in_recycling
:
# Separate batch per ensembling & recycling step.
num_ensemble
*=
num_recycling
+
1
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_
ensemble
)
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_
recycling
+
1
)
)
return
tensors
...
...
openfold/model/model.py
View file @
d71d37ff
...
...
@@ -241,31 +241,35 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
z
,
pair_mask
,
no_batch_dims
,
chunk_size
,
)
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
if
self
.
config
.
template
.
embed_angles
:
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
dim
=-
3
template_mask
=
feats
[
"template_mask"
]
if
(
torch
.
any
(
template_mask
)):
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
z
,
pair_mask
,
no_batch_dims
,
chunk_size
,
)
# [*, S, N]
torsion_angles_mask
=
feats
[
"template_torsion_angles_mask"
]
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
axis
=-
2
)
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
if
self
.
config
.
template
.
embed_angles
:
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
dim
=-
3
)
# [*, S, N]
torsion_angles_mask
=
feats
[
"template_torsion_angles_mask"
]
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
dim
=-
2
)
# Embed extra MSA features + merge with pairwise embeddings
if
self
.
config
.
extra_msa
.
enabled
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment