Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
591d10d2
Commit
591d10d2
authored
Apr 19, 2022
by
Gustaf Ahdritz
Browse files
Fix template bugs
parent
8c169fb6
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
56 additions
and
29 deletions
+56
-29
openfold/config.py
openfold/config.py
+2
-0
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+31
-12
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+2
-1
openfold/model/embedders.py
openfold/model/embedders.py
+9
-12
openfold/model/model.py
openfold/model/model.py
+1
-0
openfold/utils/feats.py
openfold/utils/feats.py
+9
-2
openfold/utils/rigid_utils.py
openfold/utils/rigid_utils.py
+2
-2
No files found.
openfold/config.py
View file @
591d10d2
...
@@ -255,6 +255,7 @@ config = mlc.ConfigDict(
...
@@ -255,6 +255,7 @@ config = mlc.ConfigDict(
"clamp_prob"
:
0.9
,
"clamp_prob"
:
0.9
,
"max_distillation_msa_clusters"
:
1000
,
"max_distillation_msa_clusters"
:
1000
,
"uniform_recycling"
:
True
,
"uniform_recycling"
:
True
,
"distillation_prob"
:
0.75
,
},
},
"data_module"
:
{
"data_module"
:
{
"use_small_bfd"
:
False
,
"use_small_bfd"
:
False
,
...
@@ -333,6 +334,7 @@ config = mlc.ConfigDict(
...
@@ -333,6 +334,7 @@ config = mlc.ConfigDict(
"eps"
:
eps
,
# 1e-6,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
False
,
},
},
"extra_msa"
:
{
"extra_msa"
:
{
"extra_msa_embedder"
:
{
"extra_msa_embedder"
:
{
...
...
openfold/data/data_transforms.py
View file @
591d10d2
...
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
...
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
def
make_one_hot
(
x
,
num_classes
):
def
make_one_hot
(
x
,
num_classes
):
x_one_hot
=
torch
.
zeros
(
*
x
.
shape
,
num_classes
)
x_one_hot
=
torch
.
zeros
(
*
x
.
shape
,
num_classes
,
device
=
x
.
device
)
x_one_hot
.
scatter_
(
-
1
,
x
.
unsqueeze
(
-
1
),
1
)
x_one_hot
.
scatter_
(
-
1
,
x
.
unsqueeze
(
-
1
),
1
)
return
x_one_hot
return
x_one_hot
...
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
...
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
)
)
# Map hhsearch-aatype to our aatype.
# Map hhsearch-aatype to our aatype.
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int64
).
expand
(
new_order
=
torch
.
tensor
(
num_templates
,
-
1
new_order_list
,
dtype
=
torch
.
int64
,
device
=
protein
[
"aatype"
].
device
,
)
)
.
expand
(
num_templates
,
-
1
)
protein
[
"template_aatype"
]
=
torch
.
gather
(
protein
[
"template_aatype"
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
"template_aatype"
]
new_order
,
1
,
index
=
protein
[
"template_aatype"
]
)
)
...
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
...
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as rc."""
"""Correct MSA restype to have the same order as rc."""
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order
=
torch
.
tensor
(
[
new_order_list
]
*
protein
[
"msa"
].
shape
[
1
],
dtype
=
protein
[
"msa"
].
dtype
[
new_order_list
]
*
protein
[
"msa"
].
shape
[
1
],
device
=
protein
[
"msa"
].
device
,
).
transpose
(
0
,
1
)
).
transpose
(
0
,
1
)
protein
[
"msa"
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
"msa"
])
protein
[
"msa"
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
"msa"
])
...
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
...
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
if
seed
is
not
None
:
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
index_order
=
torch
.
cat
(
(
torch
.
tensor
([
0
],
device
=
shuffled
.
device
),
shuffled
),
dim
=
0
)
num_sel
=
min
(
max_seq
,
num_seq
)
num_sel
=
min
(
max_seq
,
num_seq
)
sel_seq
,
not_sel_seq
=
torch
.
split
(
sel_seq
,
not_sel_seq
=
torch
.
split
(
index_order
,
[
num_sel
,
num_seq
-
num_sel
]
index_order
,
[
num_sel
,
num_seq
-
num_sel
]
...
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
...
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
def
block_delete_msa
(
protein
,
config
):
def
block_delete_msa
(
protein
,
config
):
num_seq
=
protein
[
"msa"
].
shape
[
0
]
num_seq
=
protein
[
"msa"
].
shape
[
0
]
block_num_seq
=
torch
.
floor
(
block_num_seq
=
torch
.
floor
(
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
)
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
,
device
=
protein
[
"msa"
].
device
)
*
config
.
msa_fraction_per_block
*
config
.
msa_fraction_per_block
).
to
(
torch
.
int32
)
).
to
(
torch
.
int32
)
...
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
...
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
@
curry1
@
curry1
def
nearest_neighbor_clusters
(
protein
,
gap_agreement_weight
=
0.0
):
def
nearest_neighbor_clusters
(
protein
,
gap_agreement_weight
=
0.0
):
weights
=
torch
.
cat
(
weights
=
torch
.
cat
(
[
torch
.
ones
(
21
),
gap_agreement_weight
*
torch
.
ones
(
1
),
torch
.
zeros
(
1
)],
[
torch
.
ones
(
21
,
device
=
protein
[
"msa"
].
device
),
gap_agreement_weight
*
torch
.
ones
(
1
,
device
=
protein
[
"msa"
].
device
),
torch
.
zeros
(
1
,
device
=
protein
[
"msa"
].
device
)
],
0
,
0
,
)
)
...
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
...
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
)
)
segment_ids
=
segment_ids
.
expand
(
data
.
shape
)
segment_ids
=
segment_ids
.
expand
(
data
.
shape
)
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add_
(
0
,
segment_ids
,
data
.
float
())
tensor
=
(
torch
.
zeros
(
*
shape
,
device
=
segment_ids
.
device
)
.
scatter_add_
(
0
,
segment_ids
,
data
.
float
())
)
tensor
=
tensor
.
type
(
data
.
dtype
)
tensor
=
tensor
.
type
(
data
.
dtype
)
return
tensor
return
tensor
...
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
...
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
@
curry1
@
curry1
def
add_constant_field
(
protein
,
key
,
value
):
def
add_constant_field
(
protein
,
key
,
value
):
protein
[
key
]
=
torch
.
tensor
(
value
)
protein
[
key
]
=
torch
.
tensor
(
value
,
device
=
protein
[
"msa"
].
device
)
return
protein
return
protein
...
@@ -431,7 +442,11 @@ def make_hhblits_profile(protein):
...
@@ -431,7 +442,11 @@ def make_hhblits_profile(protein):
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
"""Create data for BERT on raw MSA."""
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
# Add a random amino acid uniformly.
random_aa
=
torch
.
tensor
([
0.05
]
*
20
+
[
0.0
,
0.0
],
dtype
=
torch
.
float32
)
random_aa
=
torch
.
tensor
(
[
0.05
]
*
20
+
[
0.0
,
0.0
],
dtype
=
torch
.
float32
,
device
=
protein
[
"aatype"
].
device
)
categorical_probs
=
(
categorical_probs
=
(
config
.
uniform_prob
*
random_aa
config
.
uniform_prob
*
random_aa
...
@@ -644,7 +659,11 @@ def make_atom14_masks(protein):
...
@@ -644,7 +659,11 @@ def make_atom14_masks(protein):
def
make_atom14_masks_np
(
batch
):
def
make_atom14_masks_np
(
batch
):
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
,
device
=
batch
[
"aatype"
].
device
),
batch
,
np
.
ndarray
)
out
=
make_atom14_masks
(
batch
)
out
=
make_atom14_masks
(
batch
)
out
=
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
out
)
out
=
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
out
)
return
out
return
out
...
...
openfold/data/feature_pipeline.py
View file @
591d10d2
...
@@ -40,10 +40,11 @@ def np_to_tensor_dict(
...
@@ -40,10 +40,11 @@ def np_to_tensor_dict(
Returns:
Returns:
A dictionary of features mapping feature names to features. Only the given
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
features are returned, all other ones are filtered out.
"""
"""
tensor_dict
=
{
tensor_dict
=
{
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
}
}
return
tensor_dict
return
tensor_dict
...
...
openfold/model/embedders.py
View file @
591d10d2
...
@@ -165,8 +165,6 @@ class RecyclingEmbedder(nn.Module):
...
@@ -165,8 +165,6 @@ class RecyclingEmbedder(nn.Module):
self
.
no_bins
=
no_bins
self
.
no_bins
=
no_bins
self
.
inf
=
inf
self
.
inf
=
inf
self
.
bins
=
None
self
.
linear
=
Linear
(
self
.
no_bins
,
self
.
c_z
)
self
.
linear
=
Linear
(
self
.
no_bins
,
self
.
c_z
)
self
.
layer_norm_m
=
LayerNorm
(
self
.
c_m
)
self
.
layer_norm_m
=
LayerNorm
(
self
.
c_m
)
self
.
layer_norm_z
=
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_z
=
LayerNorm
(
self
.
c_z
)
...
@@ -191,15 +189,14 @@ class RecyclingEmbedder(nn.Module):
...
@@ -191,15 +189,14 @@ class RecyclingEmbedder(nn.Module):
z:
z:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
if
self
.
bins
is
None
:
bins
=
torch
.
linspace
(
self
.
bins
=
torch
.
linspace
(
self
.
min_bin
,
self
.
min_bin
,
self
.
max_bin
,
self
.
max_bin
,
self
.
no_bins
,
self
.
no_bins
,
dtype
=
x
.
dtype
,
dtype
=
x
.
dtype
,
device
=
x
.
device
,
device
=
x
.
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
# [*, N, C_m]
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
m_update
=
self
.
layer_norm_m
(
m
)
...
@@ -207,7 +204,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -207,7 +204,7 @@ class RecyclingEmbedder(nn.Module):
# This squared method might become problematic in FP16 mode.
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
# couldn't find in time.
squared_bins
=
self
.
bins
**
2
squared_bins
=
bins
**
2
upper
=
torch
.
cat
(
upper
=
torch
.
cat
(
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
)
)
...
...
openfold/model/model.py
View file @
591d10d2
...
@@ -131,6 +131,7 @@ class AlphaFold(nn.Module):
...
@@ -131,6 +131,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_t]
# [*, S_t, N, N, C_t]
t
=
build_template_pair_feat
(
t
=
build_template_pair_feat
(
single_template_feats
,
single_template_feats
,
use_unit_vector
=
self
.
config
.
template
.
use_unit_vector
,
inf
=
self
.
config
.
template
.
inf
,
inf
=
self
.
config
.
template
.
inf
,
eps
=
self
.
config
.
template
.
eps
,
eps
=
self
.
config
.
template
.
eps
,
**
self
.
config
.
template
.
distogram
,
**
self
.
config
.
template
.
distogram
,
...
...
openfold/utils/feats.py
View file @
591d10d2
...
@@ -90,7 +90,10 @@ def build_template_angle_feat(template_feats):
...
@@ -90,7 +90,10 @@ def build_template_angle_feat(template_feats):
def
build_template_pair_feat
(
def
build_template_pair_feat
(
batch
,
min_bin
,
max_bin
,
no_bins
,
eps
=
1e-20
,
inf
=
1e8
batch
,
min_bin
,
max_bin
,
no_bins
,
use_unit_vector
=
False
,
eps
=
1e-20
,
inf
=
1e8
):
):
template_mask
=
batch
[
"template_pseudo_beta_mask"
]
template_mask
=
batch
[
"template_pseudo_beta_mask"
]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
...
@@ -101,7 +104,7 @@ def build_template_pair_feat(
...
@@ -101,7 +104,7 @@ def build_template_pair_feat(
(
tpb
[...,
None
,
:]
-
tpb
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdim
=
True
(
tpb
[...,
None
,
:]
-
tpb
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdim
=
True
)
)
lower
=
torch
.
linspace
(
min_bin
,
max_bin
,
no_bins
,
device
=
tpb
.
device
)
**
2
lower
=
torch
.
linspace
(
min_bin
,
max_bin
,
no_bins
,
device
=
tpb
.
device
)
**
2
upper
=
torch
.
cat
([
lower
[
:
-
1
],
lower
.
new_tensor
([
inf
])],
dim
=-
1
)
upper
=
torch
.
cat
([
lower
[
1
:
],
lower
.
new_tensor
([
inf
])],
dim
=-
1
)
dgram
=
((
dgram
>
lower
)
*
(
dgram
<
upper
)).
type
(
dgram
.
dtype
)
dgram
=
((
dgram
>
lower
)
*
(
dgram
<
upper
)).
type
(
dgram
.
dtype
)
to_concat
=
[
dgram
,
template_mask_2d
[...,
None
]]
to_concat
=
[
dgram
,
template_mask_2d
[...,
None
]]
...
@@ -143,6 +146,10 @@ def build_template_pair_feat(
...
@@ -143,6 +146,10 @@ def build_template_pair_feat(
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
unit_vector
=
rigid_vec
*
inv_distance_scalar
[...,
None
]
unit_vector
=
rigid_vec
*
inv_distance_scalar
[...,
None
]
if
(
not
use_unit_vector
):
unit_vector
=
unit_vector
*
0.
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
append
(
template_mask_2d
[...,
None
])
to_concat
.
append
(
template_mask_2d
[...,
None
])
...
...
openfold/utils/rigid_utils.py
View file @
591d10d2
...
@@ -1352,8 +1352,8 @@ class Rigid:
...
@@ -1352,8 +1352,8 @@ class Rigid:
c2_rots
[...,
0
,
0
]
=
cos_c2
c2_rots
[...,
0
,
0
]
=
cos_c2
c2_rots
[...,
0
,
2
]
=
sin_c2
c2_rots
[...,
0
,
2
]
=
sin_c2
c2_rots
[...,
1
,
1
]
=
1
c2_rots
[...,
1
,
1
]
=
1
c1_rots
[...,
2
,
0
]
=
-
1
*
sin_c2
c2_rots
[...,
2
,
0
]
=
-
1
*
sin_c2
c1_rots
[...,
2
,
2
]
=
cos_c2
c2_rots
[...,
2
,
2
]
=
cos_c2
c_rots
=
rot_matmul
(
c2_rots
,
c1_rots
)
c_rots
=
rot_matmul
(
c2_rots
,
c1_rots
)
n_xyz
=
rot_vec_mul
(
c_rots
,
n_xyz
)
n_xyz
=
rot_vec_mul
(
c_rots
,
n_xyz
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment