OpenDAS / OpenFold / Commits

Commit 08afe382 (unverified)
Authored Aug 03, 2023 by Dingquan Yu; committed by GitHub, Aug 03, 2023

Merge branch 'multimer' into permutation

Parents: f44e9830, 4ca64437
Showing 17 changed files with 659 additions and 256 deletions (+659, -256):
openfold/config.py                                  +123  -107
openfold/data/data_pipeline.py                        +7   -25
openfold/data/data_transforms.py                     +11   -12
openfold/data/data_transforms_multimer.py           +176    -0
openfold/data/input_pipeline.py                       +2    -1
openfold/data/templates.py                           +20   -30
openfold/model/embedders.py                          +15   -15
openfold/model/model.py                               +4    -7
openfold/model/structure_module.py                  +238   -12
openfold/model/triangular_multiplicative_update.py    +4   -12
openfold/utils/all_atom_multimer.py                   +6    -6
openfold/utils/geometry/quat_rigid.py                 +1    -1
openfold/utils/geometry/rotation_matrix.py           +10   -10
openfold/utils/import_weights.py                     +34   -14
run_pretrained_openfold.py                            +1    -1
scripts/generate_alphafold_feature_dict.py            +1    -1
tests/test_import_weights.py                          +6    -2
openfold/config.py

@@ -155,7 +155,7 @@ def model_config(
         c.loss.tm.weight = 0.1
     elif "multimer" in name:
         c.globals.is_multimer = True
-        c.globals.bfloat16 = True
+        c.globals.bfloat16 = False
         c.globals.bfloat16_output = False
         c.loss.masked_msa.num_classes = 22
         c.data.common.max_recycling_iters = 20

@@ -593,6 +593,12 @@ config = mlc.ConfigDict(
                     "c_out": 37,
                 },
             },
+            # A negative value indicates that no early stopping will occur, i.e.
+            # the model will always run `max_recycling_iters` number of recycling
+            # iterations. A positive value will enable early stopping if the
+            # difference in pairwise distances is less than the tolerance between
+            # recycling steps.
+            "recycle_early_stop_tolerance": -1.
         },
         "relax": {
             "max_iterations": 0,  # no max

@@ -673,12 +679,6 @@ config = mlc.ConfigDict(
             "eps": eps,
         },
         "ema": {"decay": 0.999},
-        # A negative value indicates that no early stopping will occur, i.e.
-        # the model will always run `max_recycling_iters` number of recycling
-        # iterations. A positive value will enable early stopping if the
-        # difference in pairwise distances is less than the tolerance between
-        # recycling steps.
-        "recycle_early_stop_tolerance": -1.
     }
 )

@@ -694,6 +694,20 @@ multimer_model_config_update = {
             "max_relative_idx": 32,
             "use_chain_relative": True,
         },
+        "template": {
+            "distogram": {
+                "min_bin": 3.25,
+                "max_bin": 50.75,
+                "no_bins": 39,
+            },
+            "template_pair_embedder": {
+                "c_z": c_z,
+                "c_m": c_m,
+                "relpos_k": 32,
+                "max_relative_chain": 2,
+                "max_relative_idx": 32,
+                "use_chain_relative": True,
+            },
         "template": {
             "distogram": {
                 "min_bin": 3.25,

@@ -827,6 +841,8 @@ multimer_model_config_update = {
         },
         "recycle_early_stop_tolerance": 0.5
     },
+    "recycle_early_stop_tolerance": 0.5
+    },
     "loss": {
         "distogram": {
             "min_bin": 2.3125,
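Note on the config changes above: recycle_early_stop_tolerance moves out of the trailing model block, and the multimer update enables it at 0.5. A minimal sketch (not part of this commit) of toggling such a ConfigDict knob; the nesting below is illustrative, not the repository's exact layout:

import ml_collections as mlc

# Illustrative nesting only; the real key sits inside OpenFold's larger config.
config = mlc.ConfigDict(
    {"model": {"recycle_early_stop_tolerance": -1.0}}  # negative: never stop early
)

# The multimer update enables early stopping, as in the diff above.
config.model.recycle_early_stop_tolerance = 0.5
assert config.model.recycle_early_stop_tolerance > 0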
openfold/data/data_pipeline.py

@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
 import numpy as np

 from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
-from openfold.data.templates import get_custom_template_features
+from openfold.data.templates import get_custom_template_features, empty_template_feats
 from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
 from openfold.data.tools.utils import to_date
 from openfold.np import residue_constants, protein

@@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray]
 TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]

-def empty_template_feats(n_res) -> FeatureDict:
-    return {
-        "template_aatype": np.zeros((0, n_res)).astype(np.int64),
-        "template_all_atom_positions": np.zeros((0, n_res, 37, 3)).astype(np.float32),
-        "template_sum_probs": np.zeros((0, 1)).astype(np.float32),
-        "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
-    }
-
 def make_template_features(
     input_sequence: str,
     hits: Sequence[Any],
     template_featurizer: Any,
-    query_pdb_code: Optional[str] = None,
-    query_release_date: Optional[str] = None,
 ) -> FeatureDict:
     hits_cat = sum(hits.values(), [])
     if (len(hits_cat) == 0 or template_featurizer is None):

@@ -61,11 +49,6 @@ def make_template_features(
         )
         template_features = templates_result.features

-        # The template featurizer doesn't format empty template features
-        # properly. This is a quick fix.
-        if (template_features["template_aatype"].shape[0] == 0):
-            template_features = empty_template_feats(len(input_sequence))
-
     return template_features

@@ -453,7 +436,8 @@ class AlignmentRunner:
         if (uniprot_database_path is not None):
             self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
                 binary_path=jackhmmer_binary_path,
-                database_path=uniprot_database_path
+                database_path=uniprot_database_path,
+                n_cpu=no_cpus
             )

         if (template_searcher is not None and

@@ -839,7 +823,6 @@ class DataPipeline:
             fp.close()
         return all_hits
-
     def _get_msas(self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,

@@ -944,15 +927,14 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)

         input_sequence = mmcif.chain_to_seqres[chain_id]
-        hits = self._parse_template_hits(
+        hits = self._parse_template_hit_files(
             alignment_dir,
             alignment_index, input_sequence
         )
         template_features = make_template_features(
             input_sequence,
             hits,
-            self.template_featurizer,
-            query_release_date=to_date(mmcif.header["release_date"])
+            self.template_featurizer
         )

         msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

@@ -993,7 +975,7 @@ class DataPipeline:
             is_distillation=is_distillation
         )

-        hits = self._parse_template_hits(
+        hits = self._parse_template_hit_files(
             alignment_dir,
             alignment_index, input_sequence
         )

@@ -1025,7 +1007,7 @@ class DataPipeline:
         description = os.path.splitext(os.path.basename(core_path))[0].upper()
         core_feats = make_protein_features(protein_object, description)

-        hits = self._parse_template_hits(
+        hits = self._parse_template_hit_files(
             alignment_dir,
             alignment_index, input_sequence
         )
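Note: empty_template_feats now lives in openfold.data.templates and is imported here, so the zero-hit path and the featurizers share one definition. Illustrative usage, assuming OpenFold at this commit is importable:

from openfold.data.templates import empty_template_feats

input_sequence = "MKTAYIAKQR"  # toy sequence
feats = empty_template_feats(len(input_sequence))
# Every per-template array has a leading dimension of 0, i.e. "no templates".
print(feats["template_sum_probs"].shape)  # (0, 1)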
openfold/data/data_transforms.py

@@ -89,7 +89,6 @@ def make_all_atom_aatype(protein):
 def fix_templates_aatype(protein):
     # Map one-hot to indices
     num_templates = protein["template_aatype"].shape[0]
-    if (num_templates > 0):
-        protein["template_aatype"] = torch.argmax(
-            protein["template_aatype"], dim=-1
-        )
+    protein["template_aatype"] = torch.argmax(
+        protein["template_aatype"], dim=-1
+    )
openfold/data/data_transforms_multimer.py

@@ -2,7 +2,9 @@ from typing import Sequence
 import torch

+from openfold.config import NUM_RES
 from openfold.data.data_transforms import curry1
+from openfold.np import residue_constants as rc
 from openfold.utils.tensor_utils import masked_mean

@@ -301,3 +303,177 @@ def make_msa_profile(batch):
     )
     return batch

+def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
+    coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
+    pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
+
+    diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
+    pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
+    mask = diff_chain_mask[..., None] * pair_mask
+
+    min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
+
+    valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
+    interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
+
+    return interface_residues_idxs
+
+def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
+    positions = protein["all_atom_positions"]
+    atom_mask = protein["all_atom_mask"]
+    asym_id = protein["asym_id"]
+
+    interface_residues = get_interface_residues(positions=positions,
+                                                atom_mask=atom_mask,
+                                                asym_id=asym_id,
+                                                interface_threshold=interface_threshold)
+
+    if not torch.any(interface_residues):
+        return get_contiguous_crop_idx(protein, crop_size, generator)
+
+    target_res = interface_residues[int(torch.randint(
+        0,
+        interface_residues.shape[-1],
+        (1,),
+        device=positions.device,
+        generator=generator
+    )[0])]
+
+    ca_idx = rc.atom_order["CA"]
+    ca_positions = positions[..., ca_idx, :]
+    ca_mask = atom_mask[..., ca_idx].bool()
+
+    coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
+    ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
+
+    to_target_distances = ca_pairwise_dists[target_res]
+    break_tie = (
+        torch.arange(
+            0, to_target_distances.shape[-1], device=positions.device
+        ).float() * 1e-3
+    )
+    to_target_distances = torch.where(
+        ca_mask[..., None], to_target_distances, torch.inf
+    ) + break_tie
+
+    ret = torch.argsort(to_target_distances)[:crop_size]
+    return ret.sort().values
+
+def randint(lower, upper, generator, device):
+    return int(torch.randint(
+        lower,
+        upper + 1,
+        (1,),
+        device=device,
+        generator=generator,
+    )[0])
+
+def get_contiguous_crop_idx(protein, crop_size, generator):
+    num_res = protein["aatype"].shape[0]
+    if num_res <= crop_size:
+        return torch.arange(num_res)
+
+    _, chain_lens = protein["asym_id"].unique(return_counts=True)
+    shuffle_idx = torch.randperm(
+        chain_lens.shape[-1], device=chain_lens.device, generator=generator
+    )
+
+    num_remaining = int(chain_lens.sum())
+    num_budget = crop_size
+    crop_idxs = []
+    asym_offset = torch.tensor(0, dtype=torch.int64)
+
+    for j, idx in enumerate(shuffle_idx):
+        this_len = int(chain_lens[idx])
+        num_remaining -= this_len
+        # num res at most we can keep in this ent
+        crop_size_max = min(num_budget, this_len)
+        # num res at least we shall keep in this ent
+        crop_size_min = min(this_len, max(0, num_budget - num_remaining))
+        chain_crop_size = randint(lower=crop_size_min,
+                                  upper=crop_size_max + 1,
+                                  generator=generator,
+                                  device=chain_lens.device)
+
+        chain_start = randint(lower=0,
+                              upper=this_len - chain_crop_size + 1,
+                              generator=generator,
+                              device=chain_lens.device)
+
+        crop_idxs.append(
+            torch.arange(asym_offset + chain_start,
+                         asym_offset + chain_start + chain_crop_size)
+        )
+
+        asym_offset += this_len
+        num_budget -= chain_crop_size
+
+    return torch.concat(crop_idxs)
+
+@curry1
+def random_crop_to_size(
+    protein,
+    crop_size,
+    max_templates,
+    shape_schema,
+    spatial_crop_prob,
+    interface_threshold,
+    subsample_templates=False,
+    seed=None,
+):
+    """Crop randomly to `crop_size`, or keep as is if shorter than that."""
+    # We want each ensemble to be cropped the same way
+    g = torch.Generator(device=protein["seq_length"].device)
+    if seed is not None:
+        g.manual_seed(seed)
+
+    use_spatial_crop = torch.rand(
+        (1,), device=protein["seq_length"].device, generator=g
+    ) < spatial_crop_prob
+
+    if use_spatial_crop:
+        crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
+    else:
+        crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
+
+    if "template_mask" in protein:
+        num_templates = protein["template_mask"].shape[-1]
+    else:
+        num_templates = 0
+
+    # No need to subsample templates if there aren't any
+    subsample_templates = subsample_templates and num_templates
+
+    if subsample_templates:
+        templates_crop_start = randint(lower=0,
+                                       upper=num_templates + 1,
+                                       generator=g,
+                                       device=protein["seq_length"].device)
+        templates_select_indices = torch.randperm(
+            num_templates, device=protein["seq_length"].device, generator=g
+        )
+    else:
+        templates_crop_start = 0
+
+    num_res_crop_size = min(int(protein["seq_length"]), crop_size)
+    num_templates_crop_size = min(
+        num_templates - templates_crop_start, max_templates
+    )
+
+    for k, v in protein.items():
+        if k not in shape_schema or (
+            "template" not in k and NUM_RES not in shape_schema[k]
+        ):
+            continue
+
+        # randomly permute the templates before cropping them.
+        if k.startswith("template") and subsample_templates:
+            v = v[templates_select_indices]
+
+        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
+            is_num_res = dim_size == NUM_RES
+            if i == 0 and k.startswith("template"):
+                crop_size = num_templates_crop_size
+                crop_start = templates_crop_start
+                v = v[slice(crop_start, crop_start + crop_size)]
+            elif is_num_res:
+                v = torch.index_select(v, i, crop_idxs)
+
+        protein[k] = v
+
+    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
+
+    return protein
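A worked toy example (not part of the diff) of the per-chain budget arithmetic in get_contiguous_crop_idx above: each chain may keep at most min(num_budget, this_len) residues, and must keep at least max(0, num_budget - num_remaining) so the chains still to come can fill the crop:

chain_lens = [30, 50, 20]   # three chains, 100 residues total
crop_size = 64
num_remaining = sum(chain_lens)
num_budget = crop_size

for this_len in chain_lens:            # assume the shuffle produced this order
    num_remaining -= this_len
    crop_size_max = min(num_budget, this_len)
    crop_size_min = min(this_len, max(0, num_budget - num_remaining))
    print(this_len, (crop_size_min, crop_size_max))
    num_budget -= crop_size_max        # e.g. always take the maximum

# Prints 30 (0, 30), 50 (14, 34), 20 (0, 0) for this trajectory; the lower
# bound guarantees the full crop_size can always still be reached.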
openfold/data/input_pipeline.py

@@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
     # the masked locations and secret corrupted locations.
     transforms.append(
         data_transforms.make_masked_msa(
-            common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction, seed=msa_seed
+            common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
+            seed=(msa_seed + 1) if msa_seed else None,
         )
     )
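One behavioral detail of the new seed expression, checked standalone: it tests truthiness rather than `is not None`, so a seed of 0 also maps to None:

for msa_seed in (None, 0, 41):
    print(msa_seed, (msa_seed + 1) if msa_seed else None)
# None -> None, 0 -> None, 41 -> 42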
openfold/data/templates.py

@@ -89,6 +89,24 @@ TEMPLATE_FEATURES = {
 }

+def empty_template_feats(n_res):
+    return {
+        "template_aatype": np.zeros(
+            (0, n_res, len(residue_constants.restypes_with_x_and_gap)), np.float32
+        ),
+        "template_all_atom_mask": np.zeros(
+            (0, n_res, residue_constants.atom_type_num), np.float32
+        ),
+        "template_all_atom_positions": np.zeros(
+            (0, n_res, residue_constants.atom_type_num, 3), np.float32
+        ),
+        "template_domain_names": np.array([''.encode()], dtype=np.object),
+        "template_sequence": np.array([''.encode()], dtype=np.object),
+        "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
+    }
+
 def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
     """Returns PDB id and chain id for an HHSearch Hit."""
     # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.

@@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
         else:
             num_res = len(query_sequence)
             # Construct a default template with all zeros.
-            template_features = {
-                "template_aatype": np.zeros(
-                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)), np.float32
-                ),
-                "template_all_atom_masks": np.zeros(
-                    (1, num_res, residue_constants.atom_type_num), np.float32
-                ),
-                "template_all_atom_positions": np.zeros(
-                    (1, num_res, residue_constants.atom_type_num, 3), np.float32
-                ),
-                "template_domain_names": np.array([''.encode()], dtype=np.object),
-                "template_sequence": np.array([''.encode()], dtype=np.object),
-                "template_sum_probs": np.array([0], dtype=np.float32),
-            }
+            template_features = empty_template_feats(num_res)
         return TemplateSearchResult(
             features=template_features, errors=errors, warnings=warnings

@@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
         else:
             num_res = len(query_sequence)
             # Construct a default template with all zeros.
-            template_features = {
-                "template_aatype": np.zeros(
-                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)), np.float32
-                ),
-                "template_all_atom_masks": np.zeros(
-                    (1, num_res, residue_constants.atom_type_num), np.float32
-                ),
-                "template_all_atom_positions": np.zeros(
-                    (1, num_res, residue_constants.atom_type_num, 3), np.float32
-                ),
-                "template_domain_names": np.array([''.encode()], dtype=np.object),
-                "template_sequence": np.array([''.encode()], dtype=np.object),
-                "template_sum_probs": np.array([0], dtype=np.float32),
-            }
+            template_features = empty_template_feats(num_res)
         return TemplateSearchResult(
             features=template_features,
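A self-contained NumPy sketch of the shapes the new empty_template_feats produces: zero templates in the leading dimension, with trailing dimensions that still match real template features (22 residue types including X and gap; 37 atom types). Plain `object` stands in for the deprecated np.object alias used in the diff:

import numpy as np

n_res, n_restypes, n_atoms = 10, 22, 37
feats = {
    "template_aatype": np.zeros((0, n_res, n_restypes), np.float32),
    "template_all_atom_mask": np.zeros((0, n_res, n_atoms), np.float32),
    "template_all_atom_positions": np.zeros((0, n_res, n_atoms, 3), np.float32),
    "template_domain_names": np.array([b""], dtype=object),
    "template_sequence": np.array([b""], dtype=object),
    "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}
for k, v in feats.items():
    print(k, v.shape)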
openfold/model/embedders.py

@@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module):
         entity_id = batch["entity_id"]
         entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
-        rel_feats.append(entity_id_same[..., None])
+        rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))

         sym_id = batch["sym_id"]
         rel_sym_id = sym_id[..., None] - sym_id[..., None, :]

@@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module):
         # a second copy during the stack later on
         t_pair = z.new_zeros(
             z.shape[:-3] +
-            (n_templ, n, n, self.config.template_pair_embedder.c_t)
+            (n_templ, n, n, self.config.template_pair_embedder.c_out)
         )

         for i in range(n_templ):

@@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module):
     ):
         super(TemplatePairEmbedderMultimer, self).__init__()

-        self.dgram_linear = Linear(c_dgram, c_out)
-        self.aatype_linear_1 = Linear(c_aatype, c_out)
-        self.aatype_linear_2 = Linear(c_aatype, c_out)
+        self.dgram_linear = Linear(c_dgram, c_out, init='relu')
+        self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
+        self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
         self.query_embedding_layer_norm = LayerNorm(c_z)
-        self.query_embedding_linear = Linear(c_z, c_out)
-        self.pseudo_beta_mask_linear = Linear(1, c_out)
-        self.x_linear = Linear(1, c_out)
-        self.y_linear = Linear(1, c_out)
-        self.z_linear = Linear(1, c_out)
-        self.backbone_mask_linear = Linear(1, c_out)
+        self.query_embedding_linear = Linear(c_z, c_out, init='relu')
+        self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
+        self.x_linear = Linear(1, c_out, init='relu')
+        self.y_linear = Linear(1, c_out, init='relu')
+        self.z_linear = Linear(1, c_out, init='relu')
+        self.backbone_mask_linear = Linear(1, c_out, init='relu')

     def forward(self,
         template_dgram: torch.Tensor,

@@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module):
         single_template_embeds = {}
         act = 0.

-        template_positions, pseudo_beta_mask = (
-            single_template_feats["template_pseudo_beta"],
-            single_template_feats["template_pseudo_beta_mask"],
-        )
+        template_positions, pseudo_beta_mask = pseudo_beta_fn(
+            single_template_feats["template_aatype"],
+            single_template_feats["template_all_atom_positions"],
+            single_template_feats["template_all_atom_mask"]
+        )

         template_dgram = dgram_from_positions(
             template_positions,
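The TemplateEmbedderMultimer change above derives template pseudo-beta coordinates on the fly instead of reading precomputed template_pseudo_beta features. A self-contained sketch of the convention pseudo_beta_fn implements (CB per residue, falling back to CA for glycine); the residue and atom indices follow the standard AlphaFold orderings:

import torch

GLY = 7          # glycine index in the standard AlphaFold residue ordering
CA, CB = 1, 3    # atom37 indices of CA and CB

def pseudo_beta(aatype, all_atom_positions, all_atom_mask):
    # Glycine has no CB, so use CA there; everyone else gets CB.
    is_gly = (aatype == GLY)[..., None]
    pos = torch.where(is_gly, all_atom_positions[..., CA, :], all_atom_positions[..., CB, :])
    mask = torch.where(is_gly[..., 0], all_atom_mask[..., CA], all_atom_mask[..., CB])
    return pos, mask

aatype = torch.tensor([7, 0])                    # GLY, ALA
positions = torch.randn(2, 37, 3)
mask = torch.ones(2, 37)
pos, m = pseudo_beta(aatype, positions, mask)    # shapes (2, 3) and (2,)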
openfold/model/model.py

@@ -186,11 +186,6 @@ class AlphaFold(nn.Module):
         if self.config.recycle_early_stop_tolerance < 0:
             return False

-        if no_batch_dims == 0:
-            prev_pos = prev_pos.unsqueeze(dim=0)
-            next_pos = next_pos.unsqueeze(dim=0)
-            mask = mask.unsqueeze(dim=0)
-
         ca_idx = residue_constants.atom_order['CA']
         sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
         mask = mask[..., None] * mask[..., None, :]

@@ -265,7 +260,7 @@ class AlphaFold(nn.Module):
                 requires_grad=False,
             )

-            x_prev = pseudo_beta_fn(
+            pseudo_beta_x_prev = pseudo_beta_fn(
                 feats["aatype"], x_prev, None
             ).to(dtype=z.dtype)

@@ -279,10 +274,12 @@ class AlphaFold(nn.Module):
             m_1_prev_emb, z_prev_emb = self.recycling_embedder(
                 m_1_prev,
                 z_prev,
-                x_prev,
+                pseudo_beta_x_prev,
                 inplace_safe=inplace_safe,
             )

+            del pseudo_beta_x_prev
+
             if (self.globals.offload_inference and inplace_safe):
                 m = m.to(m_1_prev_emb.device)
                 z = z.to(z_prev.device)
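A self-contained sketch of the early-stop criterion recycle_early_stop_tolerance feeds, as described in the config comment: compare CA pairwise-distance matrices between consecutive recycling iterations. The exact masked aggregation in the repository may differ slightly; this is illustrative:

import torch

def distances(points):                       # [*, N, 3] -> [*, N, N]
    return torch.sqrt(((points[..., None, :] - points[..., None, :, :]) ** 2).sum(dim=-1))

def should_stop(prev_ca, next_ca, mask, tolerance=0.5):
    sq_diff = (distances(prev_ca) - distances(next_ca)) ** 2
    pair_mask = mask[..., None] * mask[..., None, :]
    mean_sq = (sq_diff * pair_mask).sum() / pair_mask.sum()
    return torch.sqrt(mean_sq) < tolerance   # stop once the change is small

prev = torch.randn(16, 3)
print(should_stop(prev, prev + 1e-3, torch.ones(16)))  # tensor(True): rigid shift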
openfold/model/structure_module.py

@@ -166,12 +166,14 @@ class PointProjection(nn.Module):
         c_hidden: int,
         num_points: int,
         no_heads: int,
+        is_multimer: bool,
         return_local_points: bool = False,
     ):
         super().__init__()
         self.return_local_points = return_local_points
         self.no_heads = no_heads
         self.num_points = num_points
+        self.is_multimer = is_multimer

         self.linear = Linear(c_hidden, no_heads * 3 * num_points)

@@ -181,24 +183,19 @@ class PointProjection(nn.Module):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         # TODO: Needs to run in high precision during training
         points_local = self.linear(activations)
-        if isinstance(rigids, Rigid3Array):
-            points_local = points_local.reshape(
-                *points_local.shape[:-1],
-                self.no_heads,
-                -1,
-            )
+        out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
+        if self.is_multimer:
+            points_local = points_local.view(
+                points_local.shape[:-1] + (self.no_heads, -1)
+            )
         points_local = torch.split(
             points_local, points_local.shape[-1] // 3, dim=-1
         )
-        points_local = torch.stack(points_local, dim=-1)
-        if not isinstance(rigids, Rigid3Array):
-            points_local = points_local.reshape(
-                *points_local.shape[:-2],
-                self.no_heads,
-                -1,
-                3
-            )
+        points_local = torch.stack(points_local, dim=-1).view(out_shape)
         points_global = rigids[..., None, None].apply(points_local)

         if (self.return_local_points):

@@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module):
         self.linear_q_points = PointProjection(
             self.c_s,
             self.no_qk_points,
-            self.no_heads
+            self.no_heads,
+            self.is_multimer
         )

         if (is_multimer):

@@ -270,12 +268,14 @@ class InvariantPointAttention(nn.Module):
                 self.c_s,
                 self.no_qk_points,
                 self.no_heads,
+                self.is_multimer
             )
             self.linear_v_points = PointProjection(
                 self.c_s,
                 self.no_v_points,
                 self.no_heads,
+                self.is_multimer
             )
         else:
             self.linear_kv = Linear(self.c_s, 2 * hc)

@@ -283,6 +283,7 @@ class InvariantPointAttention(nn.Module):
                 self.c_s,
                 self.no_qk_points + self.no_v_points,
                 self.no_heads,
+                self.is_multimer
             )

         self.linear_b = Linear(self.c_z, self.no_heads)

@@ -504,6 +505,230 @@ class InvariantPointAttention(nn.Module):
         return s

+#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above
+# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine
+# whether or not the increase in test error matters in practice.
+class InvariantPointAttentionMultimer(nn.Module):
+    """
+    Implements Algorithm 22.
+    """
+    def __init__(
+        self,
+        c_s: int,
+        c_z: int,
+        c_hidden: int,
+        no_heads: int,
+        no_qk_points: int,
+        no_v_points: int,
+        inf: float = 1e5,
+        eps: float = 1e-8,
+        is_multimer: bool = True,
+    ):
+        """
+        Args:
+            c_s:
+                Single representation channel dimension
+            c_z:
+                Pair representation channel dimension
+            c_hidden:
+                Hidden channel dimension
+            no_heads:
+                Number of attention heads
+            no_qk_points:
+                Number of query/key points to generate
+            no_v_points:
+                Number of value points to generate
+        """
+        super(InvariantPointAttentionMultimer, self).__init__()
+
+        self.c_s = c_s
+        self.c_z = c_z
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.no_qk_points = no_qk_points
+        self.no_v_points = no_v_points
+        self.inf = inf
+        self.eps = eps
+
+        # These linear layers differ from their specifications in the
+        # supplement. There, they lack bias and use Glorot initialization.
+        # Here as in the official source, they have bias and use the default
+        # Lecun initialization.
+        hc = self.c_hidden * self.no_heads
+        self.linear_q = Linear(self.c_s, hc, bias=False)
+
+        self.linear_q_points = PointProjection(
+            self.c_s,
+            self.no_qk_points,
+            self.no_heads,
+            is_multimer=True
+        )
+
+        self.linear_k = Linear(self.c_s, hc, bias=False)
+        self.linear_v = Linear(self.c_s, hc, bias=False)
+
+        self.linear_k_points = PointProjection(
+            self.c_s,
+            self.no_qk_points,
+            self.no_heads,
+            is_multimer=True
+        )
+
+        self.linear_v_points = PointProjection(
+            self.c_s,
+            self.no_v_points,
+            self.no_heads,
+            is_multimer=True
+        )
+
+        self.linear_b = Linear(self.c_z, self.no_heads)
+
+        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
+        ipa_point_weights_init_(self.head_weights)
+
+        concat_out_dim = self.no_heads * (
+            self.c_z + self.c_hidden + self.no_v_points * 4
+        )
+        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
+
+        self.softmax = nn.Softmax(dim=-2)
+
+    def forward(
+        self,
+        s: torch.Tensor,
+        z: Optional[torch.Tensor],
+        r: Union[Rigid, Rigid3Array],
+        mask: torch.Tensor,
+        inplace_safe: bool = False,
+        _offload_inference: bool = False,
+        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+            z:
+                [*, N_res, N_res, C_z] pair representation
+            r:
+                [*, N_res] transformation object
+            mask:
+                [*, N_res] mask
+        Returns:
+            [*, N_res, C_s] single representation update
+        """
+        if (_offload_inference and inplace_safe):
+            z = _z_reference_list
+        else:
+            z = [z]
+
+        a = 0.
+
+        point_variance = (max(self.no_qk_points, 1) * 9.0 / 2)
+        point_weights = math.sqrt(1.0 / point_variance)
+
+        softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))
+
+        head_weights = softplus(self.head_weights)
+        point_weights = point_weights * head_weights
+
+        #######################################
+        # Generate scalar and point activations
+        #######################################
+
+        # [*, N_res, H, P_qk]
+        q_pts = Vec3Array.from_array(self.linear_q_points(s, r))
+
+        # [*, N_res, H, P_qk, 3]
+        k_pts = Vec3Array.from_array(self.linear_k_points(s, r))
+
+        pt_att = square_euclidean_distance(
+            q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.
+        )
+        pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
+
+        a = a + pt_att
+
+        scalar_variance = max(self.c_hidden, 1) * 1.
+        scalar_weights = math.sqrt(1.0 / scalar_variance)
+
+        # [*, N_res, H * C_hidden]
+        q = self.linear_q(s)
+        k = self.linear_k(s)
+
+        # [*, N_res, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+
+        q = q * scalar_weights
+
+        a = a + torch.einsum('...qhc,...khc->...qkh', q, k)
+
+        ##########################
+        # Compute attention scores
+        ##########################
+        # [*, N_res, N_res, H]
+        b = self.linear_b(z[0])
+
+        if (_offload_inference):
+            assert (sys.getrefcount(z[0]) == 2)
+            z[0] = z[0].cpu()
+
+        a = a + b
+
+        # [*, N_res, N_res]
+        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+        square_mask = self.inf * (square_mask - 1)
+
+        a = a + square_mask.unsqueeze(-1)
+        a = a * math.sqrt(1. / 3)  # Normalize by number of logit terms (3)
+        a = self.softmax(a)
+
+        # [*, N_res, H * C_hidden]
+        v = self.linear_v(s)
+
+        # [*, N_res, H, C_hidden]
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        o = torch.einsum('...qkh, ...khc->...qhc', a, v)
+
+        # [*, N_res, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, N_res, H, P_v, 3]
+        v_pts = Vec3Array.from_array(self.linear_v_points(s, r))
+
+        # [*, N_res, H, P_v]
+        o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
+        o_pt = o_pt.sum(dim=-3)
+
+        # o_pt = Vec3Array(
+        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
+        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
+        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
+        # )
+
+        # [*, N_res, H * P_v, 3]
+        o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
+
+        # [*, N_res, H, P_v]
+        o_pt = r[..., None].apply_inverse_to_point(o_pt)
+
+        o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
+
+        # [*, N_res, H * P_v]
+        o_pt_norm = o_pt.norm(epsilon=1e-8)
+
+        if (_offload_inference):
+            z[0] = z[0].to(o_pt.device)
+
+        o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))
+
+        # [*, N_res, H * C_z]
+        o_pair = flatten_final_dims(o_pair, 2)
+
+        # [*, N_res, C_s]
+        s = self.linear_out(
+            torch.cat(
+                (o, *o_pt_flat, o_pt_norm, o_pair), dim=-1
+            ).to(dtype=z[0].dtype)
+        )
+
+        return s
+
 class BackboneUpdate(nn.Module):
     """
     Implements part of Algorithm 23.

@@ -670,7 +895,8 @@ class StructureModule(nn.Module):
         self.linear_in = Linear(self.c_s, self.c_s)

-        self.ipa = InvariantPointAttention(
+        ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer
+        self.ipa = ipa(
             self.c_s,
             self.c_z,
             self.c_ipa,
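Two identities used by InvariantPointAttentionMultimer above, verified standalone: softplus(x) = log(1 + e^x) equals torch.logaddexp(x, 0), and the attention logits are scaled by sqrt(1/3) because three terms (scalar, point, and pair bias) are summed:

import math
import torch

x = torch.randn(5)
softplus = torch.logaddexp(x, torch.zeros_like(x))
assert torch.allclose(softplus, torch.nn.functional.softplus(x))

print(math.sqrt(1. / 3))  # ~0.577, the per-logit normalizer in the new IPA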
openfold/model/triangular_multiplicative_update.py

@@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
         def compute_projection(pair, mask):
             p = compute_projection_helper(pair, mask)
-            if self._outgoing:
-                left = p[..., :self.c_hidden]
-                right = p[..., self.c_hidden:]
-            else:
-                left = p[..., self.c_hidden:]
-                right = p[..., :self.c_hidden]
+            left = p[..., :self.c_hidden]
+            right = p[..., self.c_hidden:]
             return left, right

@@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
         ab = ab * self.sigmoid(self.linear_ab_g(z))
         ab = ab * self.linear_ab_p(z)
-        if self._outgoing:
-            a = ab[..., :self.c_hidden]
-            b = ab[..., self.c_hidden:]
-        else:
-            b = ab[..., :self.c_hidden]
-            a = ab[..., self.c_hidden:]
+        a = ab[..., :self.c_hidden]
+        b = ab[..., self.c_hidden:]

         # Prevents overflow of torch.matmul in combine projections in
         # reduced-precision modes
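Why the removed _outgoing branch is dispensable: swapping which half of the fused projection counts as "left" and "right" can instead be folded into the weights at import time, which is what the LinearParamsSwap machinery in import_weights.py below appears to do. A standalone demonstration of the equivalence:

import torch

c_hidden, c_z = 4, 6
w = torch.randn(2 * c_hidden, c_z)   # fused projection weight
z = torch.randn(c_z)

p = w @ z
left, right = p[:c_hidden], p[c_hidden:]

w_swapped = torch.cat([w[c_hidden:], w[:c_hidden]], dim=0)  # swap row halves
p2 = w_swapped @ z
assert torch.allclose(p2[:c_hidden], right)
assert torch.allclose(p2[c_hidden:], left)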
openfold/utils/all_atom_multimer.py

@@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype):
 def atom14_to_atom37(
     atom14_data: torch.Tensor,  # (*, N, 14, ...)
     aatype: torch.Tensor  # (*, N)
-) -> torch.Tensor:  # (*, N, 37, ...)
+) -> Tuple:  # (*, N, 37, ...)
     """Convert atom14 to atom37 representation."""
-    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
+    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
     no_batch_dims = len(aatype.shape) - 1
     atom37_data = tensor_utils.batched_gather(
         atom14_data,

@@ -50,10 +50,10 @@ def atom14_to_atom37(
     if len(atom14_data.shape) == no_batch_dims + 2:
         atom37_data *= atom37_mask
     elif len(atom14_data.shape) == no_batch_dims + 3:
-        atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype)
+        atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype)
     else:
         raise ValueError("Incorrectly shaped data")
-    return atom37_data
+    return atom37_data, atom37_mask

 def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):

@@ -230,13 +230,13 @@ def torsion_angles_to_frames(
     num_residues = aatype.shape[-1]
     sin_angles = torch.cat(
         [
-            torch.zeros_like(aatype).unsqueeze(),
+            torch.zeros_like(aatype).unsqueeze(dim=-1),
             sin_angles,
         ],
         dim=-1
     )
     cos_angles = torch.cat(
         [
-            torch.ones_like(aatype).unsqueeze(),
+            torch.ones_like(aatype).unsqueeze(dim=-1),
             cos_angles
         ],
         dim=-1
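The masking fix above in isolation: torch.Tensor has no astype method (that is NumPy API), so the old line raised AttributeError; .to(dtype=...) is the torch spelling:

import torch

mask = torch.ones(3, 37)
data = torch.randn(3, 37, 3)
data = data * mask[..., None].to(dtype=data.dtype)  # was .astype(...), an AttributeError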
openfold/utils/geometry/quat_rigid.py

@@ -20,7 +20,7 @@ class QuatRigid(nn.Module):
     def forward(self, activations: torch.Tensor) -> Rigid3Array:
         # NOTE: During training, this needs to be run in higher precision
-        rigid_flat = self.linear(activations.to(torch.float32))
+        rigid_flat = self.linear(activations)
         rigid_flat = torch.unbind(rigid_flat, dim=-1)
         if (self.full_quat):
openfold/utils/geometry/rotation_matrix.py

@@ -172,20 +172,20 @@ class Rot3Array:
     ) -> Rot3Array:
         """Construct Rot3Array from components of quaternion."""
         if normalize:
-            inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
+            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
             w = w * inv_norm
             x = x * inv_norm
             y = y * inv_norm
             z = z * inv_norm
-        xx = 1 - 2 * (y**2 + z**2)
-        xy = 2 * (x * y - w * z)
-        xz = 2 * (x * z + w * y)
-        yx = 2 * (x * y + w * z)
-        yy = 1 - 2 * (x**2 + z**2)
-        yz = 2 * (y * z - w * x)
-        zx = 2 * (x * z - w * y)
-        zy = 2 * (y * z + w * x)
-        zz = 1 - 2 * (x**2 + y**2)
+        xx = 1.0 - 2.0 * (y**2 + z**2)
+        xy = 2.0 * (x * y - w * z)
+        xz = 2.0 * (x * z + w * y)
+        yx = 2.0 * (x * y + w * z)
+        yy = 1.0 - 2.0 * (x**2 + z**2)
+        yz = 2.0 * (y * z - w * x)
+        zx = 2.0 * (x * z - w * y)
+        zy = 2.0 * (y * z + w * x)
+        zz = 1.0 - 2.0 * (x**2 + y**2)
         return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

     def reshape(self, new_shape):
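A standalone check of the quaternion-to-rotation formula above: once (w, x, y, z) are normalized with the clamped rsqrt, the resulting matrix is orthonormal:

import torch

w, x, y, z = torch.randn(4).unbind()
inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=1e-8))
w, x, y, z = w * inv_norm, x * inv_norm, y * inv_norm, z * inv_norm

R = torch.stack([
    torch.stack([1.0 - 2.0 * (y**2 + z**2), 2.0 * (x * y - w * z), 2.0 * (x * z + w * y)]),
    torch.stack([2.0 * (x * y + w * z), 1.0 - 2.0 * (x**2 + z**2), 2.0 * (y * z - w * x)]),
    torch.stack([2.0 * (x * z - w * y), 2.0 * (y * z + w * x), 1.0 - 2.0 * (x**2 + y**2)]),
])
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)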
openfold/utils/import_weights.py

@@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
 # With Param, a poor man's enum with attributes (Rust-style)
 class ParamType(Enum):
     LinearWeight = partial(  # hack: partial prevents fns from becoming methods
-        lambda w: w.transpose(-1, -2)
+        lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
     )
     LinearWeightMHA = partial(
         lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)

@@ -58,6 +58,7 @@ class Param:
     param: Union[torch.Tensor, List[torch.Tensor]]
     param_type: ParamType = ParamType.Other
     stacked: bool = False
+    swap: bool = False

 def process_translation_dict(d, top_layer=True):

@@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None):
             param=[param.param for param in v],
             param_type=v[0].param_type,
             stacked=True,
+            swap=v[0].swap
         )
         out[k] = stacked_param

@@ -122,6 +124,11 @@ def assign(translation_dict, orig_weights):
        try:
            weights = list(map(param_type.transformation, weights))
            for p, w in zip(ref, weights):
-                p.copy_(w)
+                if param.swap:
+                    index = p.shape[0] // 2
+                    p[:index].copy_(w[index:])
+                    p[index:].copy_(w[:index])
+                else:
+                    p.copy_(w)
        except:
            print(k)

@@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False):
     LinearBiasMultimer = lambda l: (
         Param(l, param_type=ParamType.LinearBiasMultimer)
     )
+    LinearWeightSwap = lambda l: (
+        Param(l, param_type=ParamType.LinearWeight, swap=True)
+    )
+    LinearBiasSwap = lambda l: (
+        Param(l, swap=True)
+    )

     LinearParams = lambda l: {
         "weights": LinearWeight(l.weight),
         "bias": LinearBias(l.bias),
     }
+    LinearParamsMHA = lambda l: {
+        "weights": LinearWeightMHA(l.weight),
+        "bias": LinearBiasMHA(l.bias),
+    }
+    LinearParamsSwap = lambda l: {
+        "weights": LinearWeightSwap(l.weight),
+        "bias": LinearBiasSwap(l.bias),
+    }
     LinearParamsMultimer = lambda l: {
         "weights": LinearWeightMultimer(l.weight),
         "bias": LinearBiasMultimer(l.bias),

@@ -194,10 +213,11 @@ def generate_translation_dict(model, version, is_multimer=False):
     def TriMulOutParams(tri_mul, outgoing=True):
         if re.fullmatch("^model_[1-5]_multimer_v3$", version):
+            lin_param_type = LinearParams if outgoing else LinearParamsSwap
             d = {
                 "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-                "projection": LinearParams(tri_mul.linear_ab_p),
-                "gate": LinearParams(tri_mul.linear_ab_g),
+                "projection": lin_param_type(tri_mul.linear_ab_p),
+                "gate": lin_param_type(tri_mul.linear_ab_g),
                 "center_norm": LayerNormParams(tri_mul.layer_norm_out),
             }
         else:

@@ -276,24 +296,24 @@ def generate_translation_dict(model, version, is_multimer=False):
     }

     PointProjectionParams = lambda pp: {
-        "point_projection": LinearParamsMultimer(
+        "point_projection": LinearParamsMHA(
             pp.linear,
         ),
     }
     IPAParamsMultimer = lambda ipa: {
         "q_scalar_projection": {
-            "weights": LinearWeightMultimer(
+            "weights": LinearWeightMHA(
                 ipa.linear_q.weight,
             ),
         },
         "k_scalar_projection": {
-            "weights": LinearWeightMultimer(
+            "weights": LinearWeightMHA(
                 ipa.linear_k.weight,
             ),
         },
         "v_scalar_projection": {
-            "weights": LinearWeightMultimer(
+            "weights": LinearWeightMHA(
                 ipa.linear_v.weight,
             ),
         },

@@ -574,7 +594,7 @@ def generate_translation_dict(model, version, is_multimer=False):
         "template_pair_embedding_0": LinearParams(
             temp_embedder.template_pair_embedder.dgram_linear
         ),
-        "template_pair_embedding_1": LinearParamsMultimer(
+        "template_pair_embedding_1": LinearParams(
             temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
         ),
         "template_pair_embedding_2": LinearParams(

@@ -583,16 +603,16 @@ def generate_translation_dict(model, version, is_multimer=False):
         "template_pair_embedding_3": LinearParams(
             temp_embedder.template_pair_embedder.aatype_linear_2
         ),
-        "template_pair_embedding_4": LinearParamsMultimer(
+        "template_pair_embedding_4": LinearParams(
             temp_embedder.template_pair_embedder.x_linear
         ),
-        "template_pair_embedding_5": LinearParamsMultimer(
+        "template_pair_embedding_5": LinearParams(
             temp_embedder.template_pair_embedder.y_linear
         ),
-        "template_pair_embedding_6": LinearParamsMultimer(
+        "template_pair_embedding_6": LinearParams(
             temp_embedder.template_pair_embedder.z_linear
         ),
-        "template_pair_embedding_7": LinearParamsMultimer(
+        "template_pair_embedding_7": LinearParams(
             temp_embedder.template_pair_embedder.backbone_mask_linear
         ),
         "template_pair_embedding_8": LinearParams(

@@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False):
         ),
         "template_embedding_iteration": tps_blocks_params,
         "output_layer_norm": LayerNormParams(
-            model.template_embedder.template_pair_stack.layer_norm
+            temp_embedder.template_pair_stack.layer_norm
         ),
     },
     "output_linear": LinearParams(
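The amended LinearWeight transform exercised on toy tensors: 1-D parameters now become column vectors instead of raising in transpose (the motivation for the guard is my reading, not stated in the commit), while 2-D weights are transposed as before:

import torch

transform = lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)

print(transform(torch.zeros(8)).shape)     # torch.Size([8, 1])
print(transform(torch.zeros(8, 4)).shape)  # torch.Size([4, 8])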
run_pretrained_openfold.py

@@ -431,7 +431,7 @@ if __name__ == "__main__":
         help="""Postfix for output prediction filenames"""
     )
     parser.add_argument(
-        "--data_random_seed", type=str, default=None
+        "--data_random_seed", type=int, default=None
     )
     parser.add_argument(
         "--skip_relaxation", action="store_true", default=False,
scripts/generate_alphafold_feature_dict.py

@@ -45,7 +45,7 @@ def main(args):
         uniref90_database_path=args.uniref90_database_path,
         mgnify_database_path=args.mgnify_database_path,
         bfd_database_path=args.bfd_database_path,
-        uniclust30_database_path=args.uniclust30_database_path,
+        uniref30_database_path=args.uniref30_database_path,
         small_bfd_database_path=None,
         template_featurizer=template_featurizer,
         template_searcher=template_searcher,
tests/test_import_weights.py

@@ -15,7 +15,9 @@
 import torch
 import numpy as np
 import unittest
+from pathlib import Path

+from tests.config import consts
 from openfold.config import model_config
 from openfold.model.model import AlphaFold
 from openfold.utils.import_weights import import_jax_weights_

@@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_
 class TestImportWeights(unittest.TestCase):
     def test_import_jax_weights_(self):
-        npz_path = "openfold/resources/params/params_model_1_ptm.npz"
-        c = model_config("model_1_ptm")
+        npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
+        c = model_config(consts.model)
         c.globals.blocks_per_ckpt = None
         model = AlphaFold(c)
+        model.eval()

         import_jax_weights_(
             model,
             npz_path,
+            version=consts.model
         )

         data = np.load(npz_path)