Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
651949b2
Commit
651949b2
authored
Oct 08, 2021
by
Gustaf Ahdritz
Browse files
Move config script, reformat data_transforms
parent
f4150fa1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
111 additions
and
42 deletions
+111
-42
openfold/config.py
openfold/config.py
+0
-0
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+107
-38
run_pretrained_alphafold.py
run_pretrained_alphafold.py
+1
-1
tests/compare_utils.py
tests/compare_utils.py
+1
-1
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_model.py
tests/test_model.py
+1
-1
No files found.
config.py
→
openfold/
config.py
View file @
651949b2
File moved
openfold/features/data_transforms.py
View file @
651949b2
...
@@ -5,7 +5,7 @@ import numpy as np
...
@@ -5,7 +5,7 @@ import numpy as np
import
torch
import
torch
from
operator
import
add
from
operator
import
add
from
config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.
config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
MSA_FEATURE_NAMES
=
[
MSA_FEATURE_NAMES
=
[
...
@@ -29,7 +29,9 @@ def make_seq_mask(protein):
...
@@ -29,7 +29,9 @@ def make_seq_mask(protein):
return
protein
return
protein
def
make_template_mask
(
protein
):
def
make_template_mask
(
protein
):
protein
[
'template_mask'
]
=
torch
.
ones
(
protein
[
'template_aatype'
].
shape
[
0
],
dtype
=
torch
.
float32
)
protein
[
'template_mask'
]
=
torch
.
ones
(
protein
[
'template_aatype'
].
shape
[
0
],
dtype
=
torch
.
float32
)
return
protein
return
protein
def
curry1
(
f
):
def
curry1
(
f
):
...
@@ -42,7 +44,9 @@ def curry1(f):
...
@@ -42,7 +44,9 @@ def curry1(f):
@
curry1
@
curry1
def
add_distillation_flag
(
protein
,
distillation
):
def
add_distillation_flag
(
protein
,
distillation
):
protein
[
'is_distillation'
]
=
torch
.
tensor
(
float
(
distillation
),
dtype
=
torch
.
float32
)
protein
[
'is_distillation'
]
=
torch
.
tensor
(
float
(
distillation
),
dtype
=
torch
.
float32
)
return
protein
return
protein
def
make_all_atom_aatype
(
protein
):
def
make_all_atom_aatype
(
protein
):
...
@@ -55,14 +59,20 @@ def fix_templates_aatype(protein):
...
@@ -55,14 +59,20 @@ def fix_templates_aatype(protein):
protein
[
'template_aatype'
]
=
torch
.
argmax
(
protein
[
'template_aatype'
],
dim
=-
1
)
protein
[
'template_aatype'
]
=
torch
.
argmax
(
protein
[
'template_aatype'
],
dim
=-
1
)
# Map hhsearch-aatype to our aatype.
# Map hhsearch-aatype to our aatype.
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int32
).
expand
(
num_templates
,
-
1
)
new_order
=
torch
.
tensor
(
protein
[
'template_aatype'
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
'template_aatype'
])
new_order_list
,
dtype
=
torch
.
int32
).
expand
(
num_templates
,
-
1
)
protein
[
'template_aatype'
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
'template_aatype'
]
)
return
protein
return
protein
def
correct_msa_restypes
(
protein
):
def
correct_msa_restypes
(
protein
):
"""Correct MSA restype to have the same order as residue_constants."""
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
([
new_order_list
]
*
protein
[
'msa'
].
shape
[
1
],
dtype
=
protein
[
'msa'
].
dtype
).
transpose
(
0
,
1
)
new_order
=
torch
.
tensor
(
[
new_order_list
]
*
protein
[
'msa'
].
shape
[
1
],
dtype
=
protein
[
'msa'
].
dtype
).
transpose
(
0
,
1
)
protein
[
'msa'
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
'msa'
])
protein
[
'msa'
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
'msa'
])
perm_matrix
=
np
.
zeros
((
22
,
22
),
dtype
=
np
.
float32
)
perm_matrix
=
np
.
zeros
((
22
,
22
),
dtype
=
np
.
float32
)
...
@@ -94,7 +104,9 @@ def squeeze_features(protein):
...
@@ -94,7 +104,9 @@ def squeeze_features(protein):
return
protein
return
protein
def
make_protein_crop_to_size_seed
(
protein
):
def
make_protein_crop_to_size_seed
(
protein
):
protein
[
'random_crop_to_size_seed'
]
=
torch
.
distributions
.
Uniform
(
low
=
torch
.
int32
,
high
=
torch
.
int32
).
sample
((
2
))
protein
[
'random_crop_to_size_seed'
]
=
torch
.
distributions
.
Uniform
(
low
=
torch
.
int32
,
high
=
torch
.
int32
).
sample
((
2
)
)
return
protein
return
protein
@
curry1
@
curry1
...
@@ -110,8 +122,10 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
...
@@ -110,8 +122,10 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
torch
.
rand
(
protein
[
'aatype'
].
shape
)
<
replace_proportion
torch
.
rand
(
protein
[
'aatype'
].
shape
)
<
replace_proportion
)
)
protein
[
'aatype'
]
=
torch
.
where
(
aatype_mask
,
torch
.
ones_like
(
protein
[
'aatype'
])
*
x_idx
,
protein
[
'aatype'
]
=
torch
.
where
(
protein
[
'aatype'
])
aatype_mask
,
torch
.
ones_like
(
protein
[
'aatype'
])
*
x_idx
,
protein
[
'aatype'
]
)
return
protein
return
protein
@
curry1
@
curry1
...
@@ -151,7 +165,11 @@ def delete_extra_msa(protein):
...
@@ -151,7 +165,11 @@ def delete_extra_msa(protein):
@
curry1
@
curry1
def
block_delete_msa
(
protein
,
config
):
def
block_delete_msa
(
protein
,
config
):
num_seq
=
protein
[
'msa'
].
shape
[
0
]
num_seq
=
protein
[
'msa'
].
shape
[
0
]
block_num_seq
=
torch
.
floor
(
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
)
*
config
.
msa_fraction_per_block
).
to
(
torch
.
int32
)
block_num_seq
=
torch
.
floor
(
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
)
*
config
.
msa_fraction_per_block
).
to
(
torch
.
int32
)
if
config
.
randomize_num_blocks
:
if
config
.
randomize_num_blocks
:
nb
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
config
.
num_blocks
+
1
).
sample
()
nb
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
config
.
num_blocks
+
1
).
sample
()
...
@@ -195,9 +213,12 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
...
@@ -195,9 +213,12 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
# in an optimized fashion to avoid possible memory or computation blowup.
agreement
=
torch
.
matmul
(
torch
.
reshape
(
extra_one_hot
,
[
extra_num_seq
,
num_res
*
23
]),
agreement
=
torch
.
matmul
(
torch
.
reshape
(
sample_one_hot
*
weights
,
[
num_seq
,
num_res
*
23
]).
transpose
(
0
,
1
),
torch
.
reshape
(
extra_one_hot
,
[
extra_num_seq
,
num_res
*
23
]),
)
torch
.
reshape
(
sample_one_hot
*
weights
,
[
num_seq
,
num_res
*
23
]
).
transpose
(
0
,
1
),
)
# Assign each sequence in the extra sequences to the closest MSA sample
# Assign each sequence in the extra sequences to the closest MSA sample
protein
[
'extra_cluster_assignment'
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
protein
[
'extra_cluster_assignment'
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
...
@@ -213,14 +234,18 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
...
@@ -213,14 +234,18 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
:param num_segments: The number of segments.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
:return: A tensor of same data type as the data argument.
"""
"""
assert
all
([
i
in
data
.
shape
for
i
in
segment_ids
.
shape
]),
"segment_ids.shape should be a prefix of data.shape"
# segment_ids.shape should be a prefix of data.shape
assert
all
([
i
in
data
.
shape
for
i
in
segment_ids
.
shape
])
# segment_ids is a 1-D tensor repeat it to have the same shape as data
# segment_ids is a 1-D tensor repeat it to have the same shape as data
if
len
(
segment_ids
.
shape
)
==
1
:
if
len
(
segment_ids
.
shape
)
==
1
:
s
=
torch
.
prod
(
torch
.
tensor
(
data
.
shape
[
1
:])).
long
()
s
=
torch
.
prod
(
torch
.
tensor
(
data
.
shape
[
1
:])).
long
()
segment_ids
=
segment_ids
.
repeat_interleave
(
s
).
view
(
segment_ids
.
shape
[
0
],
*
data
.
shape
[
1
:])
segment_ids
=
segment_ids
.
repeat_interleave
(
s
).
view
(
segment_ids
.
shape
[
0
],
*
data
.
shape
[
1
:]
)
assert
data
.
shape
==
segment_ids
.
shape
,
"data.shape and segment_ids.shape should be equal"
# data.shape and segment_ids.shape should be equal
assert
data
.
shape
==
segment_ids
.
shape
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add
(
0
,
segment_ids
,
data
.
float
())
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add
(
0
,
segment_ids
,
data
.
float
())
...
@@ -232,7 +257,9 @@ def summarize_clusters(protein):
...
@@ -232,7 +257,9 @@ def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq
=
protein
[
'msa'
].
shape
[
0
]
num_seq
=
protein
[
'msa'
].
shape
[
0
]
def
csum
(
x
):
def
csum
(
x
):
return
unsorted_segment_sum
(
x
,
protein
[
'extra_cluster_assignment'
],
num_seq
)
return
unsorted_segment_sum
(
x
,
protein
[
'extra_cluster_assignment'
],
num_seq
)
mask
=
protein
[
'extra_msa_mask'
]
mask
=
protein
[
'extra_msa_mask'
]
mask_counts
=
1e-6
+
protein
[
'msa_mask'
]
+
csum
(
mask
)
# Include center
mask_counts
=
1e-6
+
protein
[
'msa_mask'
]
+
csum
(
mask
)
# Include center
...
@@ -292,7 +319,9 @@ def add_constant_field(protein, key, value):
...
@@ -292,7 +319,9 @@ def add_constant_field(protein, key, value):
def
shaped_categorical
(
probs
,
epsilon
=
1e-10
):
def
shaped_categorical
(
probs
,
epsilon
=
1e-10
):
ds
=
probs
.
shape
ds
=
probs
.
shape
num_classes
=
ds
[
-
1
]
num_classes
=
ds
[
-
1
]
distribution
=
torch
.
distributions
.
categorical
.
Categorical
(
torch
.
reshape
(
probs
+
epsilon
,[
-
1
,
num_classes
]))
distribution
=
torch
.
distributions
.
categorical
.
Categorical
(
torch
.
reshape
(
probs
+
epsilon
,[
-
1
,
num_classes
])
)
counts
=
distribution
.
sample
()
counts
=
distribution
.
sample
()
return
torch
.
reshape
(
counts
,
ds
[:
-
1
])
return
torch
.
reshape
(
counts
,
ds
[:
-
1
])
...
@@ -323,7 +352,9 @@ def make_masked_msa(protein, config, replace_fraction):
...
@@ -323,7 +352,9 @@ def make_masked_msa(protein, config, replace_fraction):
pad_shapes
[
1
]
=
1
pad_shapes
[
1
]
=
1
mask_prob
=
1.
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
mask_prob
=
1.
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
assert
mask_prob
>=
0.
assert
mask_prob
>=
0.
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
)
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
)
sh
=
protein
[
'msa'
].
shape
sh
=
protein
[
'msa'
].
shape
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
...
@@ -339,7 +370,14 @@ def make_masked_msa(protein, config, replace_fraction):
...
@@ -339,7 +370,14 @@ def make_masked_msa(protein, config, replace_fraction):
return
protein
return
protein
@
curry1
@
curry1
def
make_fixed_size
(
protein
,
shape_schema
,
msa_cluster_size
,
extra_msa_size
,
num_res
=
0
,
num_templates
=
0
):
def
make_fixed_size
(
protein
,
shape_schema
,
msa_cluster_size
,
extra_msa_size
,
num_res
=
0
,
num_templates
=
0
):
"""Guess at the MSA and sequence dimension to make fixed size."""
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map
=
{
pad_size_map
=
{
...
@@ -355,9 +393,13 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
...
@@ -355,9 +393,13 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
continue
continue
shape
=
list
(
v
.
shape
)
shape
=
list
(
v
.
shape
)
schema
=
shape_schema
[
k
]
schema
=
shape_schema
[
k
]
msd
=
"Rank mismatch between shape and shape schema for"
assert
len
(
shape
)
==
len
(
schema
),
(
assert
len
(
shape
)
==
len
(
schema
),
(
f
'Rank mismatch between shape and shape schema for
{
k
}
:
{
shape
}
vs
{
schema
}
'
)
f
'
{
msg
}
{
k
}
:
{
shape
}
vs
{
schema
}
'
pad_size
=
[
pad_size_map
.
get
(
s2
,
None
)
or
s1
for
(
s1
,
s2
)
in
zip
(
shape
,
schema
)]
)
pad_size
=
[
pad_size_map
.
get
(
s2
,
None
)
or
s1
for
(
s1
,
s2
)
in
zip
(
shape
,
schema
)
]
padding
=
[(
0
,
p
-
v
.
shape
[
i
])
for
i
,
p
in
enumerate
(
pad_size
)]
padding
=
[(
0
,
p
-
v
.
shape
[
i
])
for
i
,
p
in
enumerate
(
pad_size
)]
padding
.
reverse
()
padding
.
reverse
()
...
@@ -371,8 +413,11 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
...
@@ -371,8 +413,11 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
@
curry1
@
curry1
def
make_msa_feat
(
protein
):
def
make_msa_feat
(
protein
):
"""Create and concatenate MSA features."""
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for compatibility with domain datasets.
# Whether there is a domain break. Always zero for chains, but keeping for
has_break
=
torch
.
clip
(
protein
[
'between_segment_residues'
].
to
(
torch
.
float32
),
0
,
1
)
# compatibility with domain datasets.
has_break
=
torch
.
clip
(
protein
[
'between_segment_residues'
].
to
(
torch
.
float32
),
0
,
1
)
aatype_1hot
=
make_one_hot
(
protein
[
'aatype'
],
21
)
aatype_1hot
=
make_one_hot
(
protein
[
'aatype'
],
21
)
target_feat
=
[
target_feat
=
[
...
@@ -391,14 +436,20 @@ def make_msa_feat(protein):
...
@@ -391,14 +436,20 @@ def make_msa_feat(protein):
]
]
if
'cluster_profile'
in
protein
:
if
'cluster_profile'
in
protein
:
deletion_mean_value
=
(
torch
.
atan
(
protein
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
np
.
pi
))
deletion_mean_value
=
(
torch
.
atan
(
protein
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
np
.
pi
)
)
msa_feat
.
extend
([
protein
[
'cluster_profile'
],
msa_feat
.
extend
([
protein
[
'cluster_profile'
],
torch
.
unsqueeze
(
deletion_mean_value
,
dim
=-
1
),
torch
.
unsqueeze
(
deletion_mean_value
,
dim
=-
1
),
])
])
if
'extra_deletion_matrix'
in
protein
:
if
'extra_deletion_matrix'
in
protein
:
protein
[
'extra_has_deletion'
]
=
torch
.
clip
(
protein
[
'extra_deletion_matrix'
],
0.
,
1.
)
protein
[
'extra_has_deletion'
]
=
torch
.
clip
(
protein
[
'extra_deletion_value'
]
=
torch
.
atan
(
protein
[
'extra_deletion_matrix'
]
/
3.
)
*
(
2.
/
np
.
pi
)
protein
[
'extra_deletion_matrix'
],
0.
,
1.
)
protein
[
'extra_deletion_value'
]
=
torch
.
atan
(
protein
[
'extra_deletion_matrix'
]
/
3.
)
*
(
2.
/
np
.
pi
)
protein
[
'msa_feat'
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
protein
[
'msa_feat'
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
protein
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
protein
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
...
@@ -422,35 +473,51 @@ def make_atom14_masks(protein):
...
@@ -422,35 +473,51 @@ def make_atom14_masks(protein):
restype_atom14_mask
=
[]
restype_atom14_mask
=
[]
for
rt
in
residue_constants
.
restypes
:
for
rt
in
residue_constants
.
restypes
:
atom_names
=
residue_constants
.
restype_name_to_atom14_names
[
residue_constants
.
restype_1to3
[
rt
]]
atom_names
=
residue_constants
.
restype_name_to_atom14_names
[
residue_constants
.
restype_1to3
[
rt
]
]
restype_atom14_to_atom37
.
append
([
restype_atom14_to_atom37
.
append
([
(
residue_constants
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
(
residue_constants
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
])
])
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
restype_atom37_to_atom14
.
append
([
restype_atom37_to_atom14
.
append
([
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
residue_constants
.
atom_types
for
name
in
residue_constants
.
atom_types
])
])
# Since all 14 atoms are not present in every residue, use this mask to tell which atom is there in this residue
# Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
# Add dummy mapping for restype 'UNK'
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
)
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
)
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
torch
.
float32
)
)
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
)
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
torch
.
float32
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
torch
.
index_select
(
restype_atom14_to_atom37
,
0
,
protein
[
'aatype'
])
residx_atom14_to_atom37
=
torch
.
index_select
(
residx_atom14_mask
=
torch
.
index_select
(
restype_atom14_mask
,
0
,
protein
[
'aatype'
])
restype_atom14_to_atom37
,
0
,
protein
[
'aatype'
]
)
residx_atom14_mask
=
torch
.
index_select
(
restype_atom14_mask
,
0
,
protein
[
'aatype'
]
)
protein
[
'atom14_atom_exists'
]
=
residx_atom14_mask
protein
[
'atom14_atom_exists'
]
=
residx_atom14_mask
protein
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
protein
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
# create the gather indices for mapping back
# create the gather indices for mapping back
residx_atom37_to_atom14
=
torch
.
index_select
(
restype_atom37_to_atom14
,
0
,
protein
[
'aatype'
])
residx_atom37_to_atom14
=
torch
.
index_select
(
restype_atom37_to_atom14
,
0
,
protein
[
'aatype'
]
)
protein
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
protein
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
# create the corresponding mask
# create the corresponding mask
...
@@ -462,7 +529,9 @@ def make_atom14_masks(protein):
...
@@ -462,7 +529,9 @@ def make_atom14_masks(protein):
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
torch
.
index_select
(
restype_atom37_mask
,
0
,
protein
[
'aatype'
])
residx_atom37_mask
=
torch
.
index_select
(
restype_atom37_mask
,
0
,
protein
[
'aatype'
]
)
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
return
protein
return
protein
\ No newline at end of file
run_pretrained_alphafold.py
View file @
651949b2
...
@@ -31,7 +31,7 @@ import time
...
@@ -31,7 +31,7 @@ import time
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
config
import
model_config
from
openfold.
config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.np
import
residue_constants
,
protein
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
import
openfold.np.relax.relax
as
relax
...
...
tests/compare_utils.py
View file @
651949b2
...
@@ -6,7 +6,7 @@ import unittest
...
@@ -6,7 +6,7 @@ import unittest
import
numpy
as
np
import
numpy
as
np
from
config
import
model_config
from
openfold.
config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.import_weights
import
import_jax_weights_
from
openfold.utils.import_weights
import
import_jax_weights_
from
tests.config
import
consts
from
tests.config
import
consts
...
...
tests/test_import_weights.py
View file @
651949b2
...
@@ -16,7 +16,7 @@ import torch
...
@@ -16,7 +16,7 @@ import torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
config
import
model_config
from
openfold.
config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.import_weights
import
import_jax_weights_
from
openfold.utils.import_weights
import
import_jax_weights_
...
...
tests/test_model.py
View file @
651949b2
...
@@ -17,7 +17,7 @@ import torch
...
@@ -17,7 +17,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
config
import
*
from
openfold.
config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment