OpenDAS / OpenFold · Commits

Commit 961d86fc, authored Oct 08, 2021 by Gustaf Ahdritz

    Merge branch 'features' into main

Parents: 0f1b1968, e8e0b66f

Showing 17 changed files with 3206 additions and 11 deletions (+3206 -11)
config.py                                 +92   -0
openfold/features/__init__.py             +0    -0
openfold/features/data_transforms.py      +468  -0
openfold/features/feature_pipeline.py     +79   -0
openfold/features/input_pipeline.py       +131  -0
openfold/features/mmcif_parsing.py        +371  -0
openfold/features/np/__init__.py          +0    -0
openfold/features/np/data_pipeline.py     +188  -0
openfold/features/np/hhblits.py           +141  -0
openfold/features/np/hhsearch.py          +77   -0
openfold/features/np/jackhmmer.py         +185  -0
openfold/features/np/kalign.py            +90   -0
openfold/features/np/utils.py             +27   -0
openfold/features/parsers.py              +349  -0
openfold/features/templates.py            +895  -0
openfold/utils/affine_utils.py            +1    -1
run_pretrained_alphafold.py               +112  -10
config.py

@@ -66,7 +66,99 @@ chunk_size = mlc.FieldReference(4, field_type=int)

```python
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)

NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'

config = mlc.ConfigDict({
    'data': {
        'common': {
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions',
                'template_sum_probs',
                'template_aatype',
                'template_all_atom_masks',
                # 'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype',
                'residue_index',
                'msa',
                # 'sequence', # 'domain_name',
                'num_alignments',
                'seq_length',
                'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': True,
        },
        'eval': {
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': False,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
        }
    },
    # Recurring FieldReferences that can be changed globally here
    "globals": {
        "blocks_per_ckpt": blocks_per_ckpt,
        # ...
```
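The "globals" block holds recurring FieldReferences that can be changed in one place, per the comment above. A minimal sketch of that sharing behavior, assuming only the `ml_collections` package (the `shared` name and values are made up):

```python
import ml_collections as mlc

# Two entries hold the same FieldReference, mirroring how blocks_per_ckpt
# is reused across the config above.
shared = mlc.FieldReference(4, field_type=int)
cfg = mlc.ConfigDict({'globals': {'blocks_per_ckpt': shared},
                      'model': {'blocks_per_ckpt': shared}})

cfg.globals.blocks_per_ckpt = 2  # updates the underlying reference
assert cfg.model.blocks_per_ckpt == 2

# The eval.feat entries above are plain placeholder strings (NUM_RES etc.);
# make_fixed_size in data_transforms.py resolves them to concrete pad sizes.
```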
openfold/features/__init__.py (new file, empty)

openfold/features/data_transforms.py (new file)
```python
import itertools
from functools import reduce
import numpy as np
import torch
from operator import add

from config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants

MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
    'true_msa'
]


def cast_to_64bit_ints(protein):
    # We keep all ints as int64
    for k, v in protein.items():
        if v.dtype == torch.int32:
            protein[k] = v.type(torch.int64)
    return protein


def make_one_hot(x, num_classes):
    x_one_hot = torch.zeros(*x.shape, num_classes)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot


def make_seq_mask(protein):
    protein['seq_mask'] = torch.ones(protein['aatype'].shape,
                                     dtype=torch.float32)
    return protein


def make_template_mask(protein):
    protein['template_mask'] = torch.ones(
        protein['template_aatype'].shape[0], dtype=torch.float32)
    return protein


def curry1(f):
    """Supply all arguments but the first."""
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc
```
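`curry1` is what lets parameterized transforms such as `add_distillation_flag(False)` below sit in the same list as plain one-argument functions and be composed into a single pipeline. A standalone toy version of that pattern (the `scale_feature` transform and the `protein` dict are made up for illustration):

```python
import torch

def curry1(f):
    """Supply all arguments but the first (same helper as above)."""
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc

@curry1
def scale_feature(protein, key, factor):
    # Hypothetical transform: multiplies one feature in place.
    protein[key] = protein[key] * factor
    return protein

# Each entry is now a function of the protein dict alone.
transforms = [scale_feature('x', 2.0), scale_feature('x', 0.5)]
protein = {'x': torch.ones(3)}
for t in transforms:
    protein = t(protein)
assert torch.allclose(protein['x'], torch.ones(3))  # net effect: unchanged
```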
```python
# openfold/features/data_transforms.py (continued)
@curry1
def add_distillation_flag(protein, distillation):
    protein['is_distillation'] = torch.tensor(float(distillation),
                                              dtype=torch.float32)
    return protein


def make_all_atom_aatype(protein):
    protein['all_atom_aatype'] = protein['aatype']
    return protein


def fix_templates_aatype(protein):
    # Map one-hot to indices
    num_templates = protein['template_aatype'].shape[0]
    protein['template_aatype'] = torch.argmax(protein['template_aatype'],
                                              dim=-1)
    # Map hhsearch-aatype to our aatype.
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    new_order = torch.tensor(new_order_list, dtype=torch.int32).expand(
        num_templates, -1)
    protein['template_aatype'] = torch.gather(
        new_order, 1, index=protein['template_aatype'])
    return protein


def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as residue_constants."""
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    new_order = torch.tensor(
        [new_order_list] * protein['msa'].shape[1],
        dtype=protein['msa'].dtype).transpose(0, 1)
    protein['msa'] = torch.gather(new_order, 0, protein['msa'])

    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.

    for k in protein:
        if 'profile' in k:
            num_dim = protein[k].shape[-1]
            assert num_dim in [20, 21, 22], (
                'num_dim for %s out of expected range: %s' % (k, num_dim))
            # torch.dot is 1-D only; matmul applies the permutation to the
            # trailing restype dimension of the profile.
            protein[k] = torch.matmul(
                protein[k],
                torch.from_numpy(perm_matrix[:num_dim, :num_dim]))

    return protein


def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features."""
    protein['aatype'] = torch.argmax(protein['aatype'], dim=-1)
    for k in [
            'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
            'superfamily', 'deletion_matrix', 'resolution',
            'between_segment_residues', 'residue_index',
            'template_all_atom_masks']:
        if k in protein:
            final_dim = protein[k].shape[-1]
            if isinstance(final_dim, int) and final_dim == 1:
                protein[k] = torch.squeeze(protein[k], dim=-1)

    for k in ['seq_length', 'num_alignments']:
        if k in protein:
            protein[k] = protein[k][0]

    return protein


def make_protein_crop_to_size_seed(protein):
    # Draws two random values over the int32 range to seed random cropping,
    # matching the intent of the TF original (the Uniform-over-dtypes call
    # it replaces was not runnable).
    iinfo = torch.iinfo(torch.int32)
    protein['random_crop_to_size_seed'] = torch.randint(
        iinfo.min, iinfo.max, (2,))
    return protein
```
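Both `fix_templates_aatype` and `correct_msa_restypes` remap residue codes with `torch.gather` over a lookup table broadcast across rows. A toy version of the same remapping, with a made-up four-letter alphabet:

```python
import torch

# Hypothetical remap table: code i in the "foreign" alphabet becomes table[i].
table = torch.tensor([2, 0, 3, 1])
codes = torch.tensor([[0, 1, 2],
                      [3, 3, 0]])

# Expand the table to one row per sequence, then gather along dim 1,
# exactly as fix_templates_aatype does with MAP_HHBLITS_AATYPE_TO_OUR_AATYPE.
expanded = table.expand(codes.shape[0], -1)
remapped = torch.gather(expanded, 1, codes)
assert remapped.tolist() == [[2, 0, 3], [1, 1, 2]]
```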
```python
# openfold/features/data_transforms.py (continued)
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'."""
    msa_mask = (torch.rand(protein['msa'].shape) < replace_proportion)
    x_idx = 20
    gap_idx = 21
    msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx)
    protein['msa'] = torch.where(msa_mask,
                                 torch.ones_like(protein['msa']) * x_idx,
                                 protein['msa'])
    aatype_mask = (torch.rand(protein['aatype'].shape) < replace_proportion)
    protein['aatype'] = torch.where(
        aatype_mask,
        torch.ones_like(protein['aatype']) * x_idx,
        protein['aatype'])
    return protein


@curry1
def sample_msa(protein, max_seq, keep_extra):
    """Sample MSA randomly; remaining sequences are stored as `extra_*`."""
    num_seq = protein['msa'].shape[0]
    shuffled = torch.randperm(num_seq - 1) + 1
    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(index_order,
                                       [num_sel, num_seq - num_sel])

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein['extra_' + k] = torch.index_select(
                    protein[k], 0, not_sel_seq)
            protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein


@curry1
def crop_extra_msa(protein, max_extra_msa):
    num_seq = protein['extra_msa'].shape[0]
    num_sel = min(max_extra_msa, num_seq)
    select_indices = torch.randperm(num_seq)[:num_sel]
    for k in MSA_FEATURE_NAMES:
        if 'extra_' + k in protein:
            protein['extra_' + k] = torch.index_select(
                protein['extra_' + k], 0, select_indices)

    return protein


def delete_extra_msa(protein):
    for k in MSA_FEATURE_NAMES:
        if 'extra_' + k in protein:
            del protein['extra_' + k]
    return protein


# Not used in inference
@curry1
def block_delete_msa(protein, config):
    num_seq = protein['msa'].shape[0]
    block_num_seq = torch.floor(
        torch.tensor(num_seq, dtype=torch.float32) *
        config.msa_fraction_per_block).to(torch.int32)

    if config.randomize_num_blocks:
        nb = int(torch.distributions.uniform.Uniform(
            0, config.num_blocks + 1).sample())
    else:
        nb = config.num_blocks

    del_block_starts = torch.distributions.Uniform(0, num_seq).sample((nb,))
    # torch.arange stands in for the tf.range of the original pipeline.
    del_blocks = del_block_starts[:, None] + torch.arange(int(block_num_seq))
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1).long()
    del_indices = torch.unique(torch.reshape(del_blocks, [-1]))

    # Make sure we keep the original sequence: set difference between all
    # row indices past the first and the deleted rows.
    combined = torch.cat((torch.arange(1, num_seq), del_indices))
    uniques, counts = combined.unique(return_counts=True)
    difference = uniques[counts == 1]
    intersection = uniques[counts > 1]
    keep_indices = difference

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein


@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
    weights = torch.cat([
        torch.ones(21),
        gap_agreement_weight * torch.ones(1),
        torch.zeros(1)
    ], 0)

    # Make agreement score as weighted Hamming distance
    msa_one_hot = make_one_hot(protein['msa'], 23)
    sample_one_hot = (protein['msa_mask'][:, :, None] * msa_one_hot)
    extra_msa_one_hot = make_one_hot(protein['extra_msa'], 23)
    extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                     extra_msa_one_hot)

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot,
    # weights) in an optimized fashion to avoid possible memory or
    # computation blowup.
    agreement = torch.matmul(
        torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
        torch.reshape(sample_one_hot * weights,
                      [num_seq, num_res * 23]).transpose(0, 1),
    )

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein['extra_cluster_assignment'] = torch.argmax(
        agreement, dim=1).to(torch.int64)

    return protein


def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to
    tf.unsorted_segment_sum.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The segment indices tensor.
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument.
    """
    assert all([i in data.shape for i in segment_ids.shape]), (
        "segment_ids.shape should be a prefix of data.shape")

    # segment_ids is a 1-D tensor: repeat it to have the same shape as data
    if len(segment_ids.shape) == 1:
        s = torch.prod(torch.tensor(data.shape[1:])).long()
        segment_ids = segment_ids.repeat_interleave(s).view(
            segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, (
        "data.shape and segment_ids.shape should be equal")

    shape = [num_segments] + list(data.shape[1:])
    tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float())
    tensor = tensor.type(data.dtype)
    return tensor
```
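A quick check of the `unsorted_segment_sum` semantics, using the helper exactly as defined above: rows of `data` that share a segment id are summed into the same output row.

```python
import torch

data = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
segment_ids = torch.tensor([0, 1, 0])
out = unsorted_segment_sum(data, segment_ids, num_segments=2)
# Rows 0 and 2 share segment 0, so they are summed together.
assert out.tolist() == [[4., 4.], [2., 2.]]
```

This is the workhorse behind `summarize_clusters` below, where the segment ids are the per-extra-sequence cluster assignments from `nearest_neighbor_clusters`.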
```python
# openfold/features/data_transforms.py (continued)
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster."""
    num_seq = protein['msa'].shape[0]

    def csum(x):
        return unsorted_segment_sum(
            x, protein['extra_cluster_assignment'], num_seq)

    mask = protein['extra_msa_mask']
    mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center

    msa_sum = csum(mask[:, :, None] * make_one_hot(protein['extra_msa'], 23))
    msa_sum += make_one_hot(protein['msa'], 23)  # Original sequence
    protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
    del msa_sum

    del_sum = csum(mask * protein['extra_deletion_matrix'])
    del_sum += protein['deletion_matrix']  # Original sequence
    protein['cluster_deletion_mean'] = del_sum / mask_counts
    del del_sum

    return protein


def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded."""
    protein['msa_mask'] = torch.ones(protein['msa'].shape,
                                     dtype=torch.float32)
    protein['msa_row_mask'] = torch.ones(protein['msa'].shape[0],
                                         dtype=torch.float32)
    return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Create pseudo beta features."""
    is_gly = torch.eq(aatype, residue_constants.restype_order['G'])
    ca_idx = residue_constants.atom_order['CA']
    cb_idx = residue_constants.atom_order['CB']
    pseudo_beta = torch.where(
        torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :])

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx])
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


@curry1
def make_pseudo_beta(protein, prefix=''):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ['', 'template_']
    protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
        pseudo_beta_fn(
            protein['template_aatype' if prefix else 'all_atom_aatype'],
            protein[prefix + 'all_atom_positions'],
            protein['template_all_atom_masks' if prefix else
                    'all_atom_mask']))
    return protein


@curry1
def add_constant_field(protein, key, value):
    protein[key] = torch.tensor(value)
    return protein


def shaped_categorical(probs, epsilon=1e-10):
    ds = probs.shape
    num_classes = ds[-1]
    distribution = torch.distributions.categorical.Categorical(
        torch.reshape(probs + epsilon, [-1, num_classes]))
    counts = distribution.sample()
    return torch.reshape(counts, ds[:-1])


def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present."""
    if 'hhblits_profile' in protein:
        return protein

    # Compute the profile for every residue (over all MSA sequences).
    msa_one_hot = make_one_hot(protein['msa'], 22)
    protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0)
    return protein


@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA."""
    # Add a random amino acid uniformly.
    random_aa = torch.tensor([0.05] * 20 + [0., 0.], dtype=torch.float32)

    categorical_probs = (
        config.uniform_prob * random_aa +
        config.profile_prob * protein['hhblits_profile'] +
        config.same_prob * make_one_hot(protein['msa'], 22))

    # Put all remaining probability on [MASK] which is a new column
    pad_shapes = list(reduce(
        add, [(0, 0) for _ in range(len(categorical_probs.shape))]))
    pad_shapes[1] = 1
    mask_prob = (1. - config.profile_prob - config.same_prob -
                 config.uniform_prob)
    assert mask_prob >= 0.
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_shapes, value=mask_prob)

    sh = protein['msa'].shape
    mask_position = torch.rand(sh) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein['msa'])

    # Mix real and masked MSA
    protein['bert_mask'] = mask_position.to(torch.float32)
    protein['true_msa'] = protein['msa']
    protein['msa'] = bert_msa

    return protein
```
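With the `masked_msa` defaults in config.py (`profile_prob = same_prob = uniform_prob = 0.1`), the remaining 70% of the probability mass at each corrupted position goes to a `[MASK]` token that the padding appends as a new final column. A minimal sketch of just that padding step:

```python
import torch
import torch.nn.functional as F

profile_prob, same_prob, uniform_prob = 0.1, 0.1, 0.1
mask_prob = 1. - profile_prob - same_prob - uniform_prob  # ~0.7

# Toy per-position distribution over 22 restypes (2 positions).
probs = torch.full((2, 22), 1. / 22)
# pad of (0, 1) on the last dim appends one [MASK] column filled with
# mask_prob, matching pad_shapes[1] = 1 in make_masked_msa.
padded = F.pad(probs, (0, 1), value=mask_prob)
assert padded.shape == (2, 23)
assert abs(padded[0, -1].item() - mask_prob) < 1e-6
```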
```python
# openfold/features/data_transforms.py (continued)
@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res=0, num_templates=0):
    """Guess at the MSA and sequence dimension to make fixed size."""
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == 'extra_cluster_assignment':
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        assert len(shape) == len(schema), (
            f'Rank mismatch between shape and shape schema for {k}: '
            f'{shape} vs {schema}')
        pad_size = [
            pad_size_map.get(s2, None) or s1
            for (s1, s2) in zip(shape, schema)
        ]
        padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
        padding.reverse()
        padding = list(itertools.chain(*padding))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)

    return protein


@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features."""
    # Whether there is a domain break. Always zero for chains, but keeping
    # for compatibility with domain datasets.
    has_break = torch.clip(
        protein['between_segment_residues'].to(torch.float32), 0, 1)
    aatype_1hot = make_one_hot(protein['aatype'], 21)

    target_feat = [
        torch.unsqueeze(has_break, dim=-1),
        aatype_1hot,  # Everyone gets the original sequence.
    ]

    msa_1hot = make_one_hot(protein['msa'], 23)
    has_deletion = torch.clip(protein['deletion_matrix'], 0., 1.)
    deletion_value = torch.atan(
        protein['deletion_matrix'] / 3.) * (2. / np.pi)

    msa_feat = [
        msa_1hot,
        torch.unsqueeze(has_deletion, dim=-1),
        torch.unsqueeze(deletion_value, dim=-1),
    ]

    if 'cluster_profile' in protein:
        deletion_mean_value = (
            torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
        msa_feat.extend([
            protein['cluster_profile'],
            torch.unsqueeze(deletion_mean_value, dim=-1),
        ])

    if 'extra_deletion_matrix' in protein:
        protein['extra_has_deletion'] = torch.clip(
            protein['extra_deletion_matrix'], 0., 1.)
        protein['extra_deletion_value'] = torch.atan(
            protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)

    protein['msa_feat'] = torch.cat(msa_feat, dim=-1)
    protein['target_feat'] = torch.cat(target_feat, dim=-1)
    return protein


@curry1
def select_feat(protein, feature_list):
    return {k: v for k, v in protein.items() if k in feature_list}


@curry1
def crop_templates(protein, max_templates):
    for k, v in protein.items():
        if k.startswith('template_'):
            protein[k] = v[:max_templates]
    return protein


def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37)."""
    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []

    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]]
        restype_atom14_to_atom37.append([
            (residue_constants.atom_order[name] if name else 0)
            for name in atom_names
        ])
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append([
            (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
            for name in residue_constants.atom_types
        ])
        # Since all 14 atoms are not present in every residue, use this mask
        # to tell which atom is there in this residue
        restype_atom14_mask.append(
            [(1. if name else 0.) for name in atom_names])

    # Add dummy mapping for restype 'UNK' (the mask row keeps the three
    # tables the same length, so indexing with aatype 20 stays in bounds)
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.] * 14)

    restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37,
                                            dtype=torch.int32)
    restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14,
                                            dtype=torch.int32)
    restype_atom14_mask = torch.tensor(restype_atom14_mask,
                                       dtype=torch.float32)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = torch.index_select(
        restype_atom14_to_atom37, 0, protein['aatype'])
    residx_atom14_mask = torch.index_select(
        restype_atom14_mask, 0, protein['aatype'])

    protein['atom14_atom_exists'] = residx_atom14_mask
    protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = torch.index_select(
        restype_atom37_to_atom14, 0, protein['aatype'])
    protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14

    # create the corresponding mask
    restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        atom_names = residue_constants.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = residue_constants.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = torch.index_select(
        restype_atom37_mask, 0, protein['aatype'])
    protein['atom37_atom_exists'] = residx_atom37_mask

    return protein
```
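`make_fixed_size` is where the `eval.feat` schema from config.py meets real tensors: each placeholder resolves to a pad target from `pad_size_map`, and `None` keeps the existing size. A toy run using the transforms exactly as defined above:

```python
import torch
from config import NUM_RES, NUM_MSA_SEQ

shape_schema = {'msa_mask': [NUM_MSA_SEQ, NUM_RES]}
protein = {'msa_mask': torch.ones(5, 37)}

# make_fixed_size is curried: config-like arguments first, protein last.
transform = make_fixed_size(shape_schema,
                            msa_cluster_size=8,  # pads NUM_MSA_SEQ 5 -> 8
                            extra_msa_size=0,
                            num_res=64,          # pads NUM_RES 37 -> 64
                            num_templates=0)
protein = transform(protein)
assert protein['msa_mask'].shape == (8, 64)
```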
openfold/features/feature_pipeline.py (new file)
```python
import copy
import ml_collections
import torch
from typing import Mapping, Tuple, List, Optional, Dict, Sequence
import numpy as np

from openfold.features import input_pipeline

FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: A list of strings of feature names to be returned in the
            dataset.

    Returns:
        A dictionary of features mapping feature names to features. Only the
        given features are returned, all other ones are filtered out.
    """
    tensor_dict = {
        k: torch.tensor(v) for k, v in np_example.items() if k in features
    }
    return tensor_dict


def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    cfg = copy.deepcopy(config.data)
    feature_names = cfg.common.unsupervised_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    with cfg.unlocked():
        cfg.eval.crop_size = num_res
    return cfg, feature_names


def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    random_seed: int = 0,
):
    np_example = dict(np_example)
    num_res = int(np_example['seq_length'][0])
    cfg, feature_names = make_data_config(config, num_res=num_res)

    if 'deletion_matrix_int' in np_example:
        np_example['deletion_matrix'] = (
            np_example.pop('deletion_matrix_int').astype(np.float32))

    torch.manual_seed(random_seed)
    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names)
    features = input_pipeline.process_tensors_from_config(tensor_dict, cfg)

    return {k: v for k, v in features.items()}


class FeaturePipeline:
    def __init__(
        self,
        config: ml_collections.ConfigDict,
        params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
    ):
        self.config = config
        self.params = params

    def process_features(
        self,
        raw_features: FeatureDict,
        random_seed: int,
    ) -> FeatureDict:
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            random_seed=random_seed,
        )
```
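`FeaturePipeline` itself is thin glue over `np_example_to_features`. A usage sketch, where `raw_features` stands in for the NumPy feature dict produced by `DataPipeline.process` further below:

```python
from config import config  # the ConfigDict defined in config.py above
from openfold.features.feature_pipeline import FeaturePipeline

pipeline = FeaturePipeline(config)
# raw_features: placeholder for a dict of NumPy arrays with keys such as
# 'aatype', 'msa', 'deletion_matrix_int', 'seq_length', ...
processed = pipeline.process_features(raw_features, random_seed=42)
# processed maps feature names to torch.Tensors; the input pipeline adds a
# leading ensemble/recycling dimension to every feature.
```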
openfold/features/input_pipeline.py (new file)
```python
import torch
import tree  # dm-tree; needed by process_tensors_from_config below

from openfold.features import data_transforms


def nonensembled_transform_fns(data_config):
    """Input pipeline data transformers that are not ensembled."""
    common_cfg = data_config.common
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
        data_transforms.add_distillation_flag(False),
        data_transforms.squeeze_features,
        data_transforms.randomly_replace_msa_with_unknown(0.0),
        data_transforms.make_seq_mask,
        data_transforms.make_msa_mask,
        data_transforms.make_hhblits_profile,
    ]
    if common_cfg.use_templates:
        transforms.extend([
            data_transforms.fix_templates_aatype,
            data_transforms.make_template_mask,
            data_transforms.make_pseudo_beta('template_')
        ])
    transforms.extend([
        data_transforms.make_atom14_masks,
    ])
    return transforms


def ensembled_transform_fns(data_config):
    """Input pipeline data transformers that can be ensembled and averaged."""
    common_cfg = data_config.common
    eval_cfg = data_config.eval

    transforms = []

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
    else:
        pad_msa_clusters = eval_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    transforms.append(
        data_transforms.sample_msa(max_msa_clusters, keep_extra=True))

    if 'masked_msa' in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms.make_masked_msa(
                common_cfg.masked_msa,
                eval_cfg.masked_msa_replace_fraction))

    if common_cfg.msa_cluster_features:
        transforms.append(data_transforms.nearest_neighbor_clusters())
        transforms.append(data_transforms.summarize_clusters())

    # Crop after creating the cluster profiles.
    if max_extra_msa:
        transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
    else:
        transforms.append(data_transforms.delete_extra_msa)

    transforms.append(data_transforms.make_msa_feat())

    crop_feats = dict(eval_cfg.feat)

    if eval_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(data_transforms.make_fixed_size(
            crop_feats,
            pad_msa_clusters,
            common_cfg.max_extra_msa,
            eval_cfg.crop_size,
            eval_cfg.max_templates))
    else:
        transforms.append(
            data_transforms.crop_templates(eval_cfg.max_templates))

    return transforms


def process_tensors_from_config(tensors, data_config):
    """Based on the config, apply filters and transformations to the data."""

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(data_config)
        fn = compose(fns)
        d['ensemble_index'] = i
        return fn(d)

    eval_cfg = data_config.eval
    tensors = compose(nonensembled_transform_fns(data_config))(tensors)

    tensors_0 = wrap_ensemble_fn(tensors, 0)
    num_ensemble = eval_cfg.num_ensemble
    if data_config.common.resample_msa_in_recycling:
        # Separate batch per ensembling & recycling step.
        num_ensemble *= data_config.common.num_recycle + 1

    if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
        tensors = map_fn(
            lambda x: wrap_ensemble_fn(tensors, x),
            torch.arange(num_ensemble))
    else:
        tensors = tree.map_structure(lambda x: x[None], tensors_0)

    return tensors


@data_transforms.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles])
    return ensembled_dict
```
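Note that `compose` is itself curried, so `compose(fns)` yields a function of the data dict, and the transforms apply left to right. A standalone illustration with plain functions in place of the real transforms:

```python
def curry1(f):
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc

@curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x

pipeline = compose([lambda d: d + 1, lambda d: d * 10])
assert pipeline(2) == 30  # (2 + 1) * 10: left-to-right application
```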
openfold/features/mmcif_parsing.py (new file)
"""Parses the mmCIF file format."""
import
collections
import
dataclasses
import
io
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
from
Bio
import
PDB
from
Bio.Data
import
SCOPData
# Type aliases:
ChainId
=
str
PdbHeader
=
Mapping
[
str
,
Any
]
PdbStructure
=
PDB
.
Structure
.
Structure
SeqRes
=
str
MmCIFDict
=
Mapping
[
str
,
Sequence
[
str
]]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Monomer
:
id
:
str
num
:
int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
AtomSite
:
residue_name
:
str
author_chain_id
:
str
mmcif_chain_id
:
str
author_seq_num
:
str
mmcif_seq_num
:
int
insertion_code
:
str
hetatm_atom
:
str
model_num
:
int
# Used to map SEQRES index to a residue in the structure.
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
ResiduePosition
:
chain_id
:
str
residue_number
:
int
insertion_code
:
str
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
ResidueAtPosition
:
position
:
Optional
[
ResiduePosition
]
name
:
str
is_missing
:
bool
hetflag
:
str
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
MmcifObject
:
"""Representation of a parsed mmCIF file.
Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id
:
str
header
:
PdbHeader
structure
:
PdbStructure
chain_to_seqres
:
Mapping
[
ChainId
,
SeqRes
]
seqres_to_structure
:
Mapping
[
ChainId
,
Mapping
[
int
,
ResidueAtPosition
]]
raw_string
:
Any
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
ParsingResult
:
"""Returned by the parse function.
Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object
:
Optional
[
MmcifObject
]
errors
:
Mapping
[
Tuple
[
str
,
str
],
Any
]
class
ParseError
(
Exception
):
"""An error indicating that an mmCIF file could not be parsed."""
def
mmcif_loop_to_list
(
prefix
:
str
,
parsed_info
:
MmCIFDict
)
->
Sequence
[
Mapping
[
str
,
str
]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols
=
[]
data
=
[]
for
key
,
value
in
parsed_info
.
items
():
if
key
.
startswith
(
prefix
):
cols
.
append
(
key
)
data
.
append
(
value
)
assert
all
([
len
(
xs
)
==
len
(
data
[
0
])
for
xs
in
data
]),
(
'mmCIF error: Not all loops are the same length: %s'
%
cols
)
return
[
dict
(
zip
(
cols
,
xs
))
for
xs
in
zip
(
*
data
)]
def
mmcif_loop_to_dict
(
prefix
:
str
,
index
:
str
,
parsed_info
:
MmCIFDict
,
)
->
Mapping
[
str
,
Mapping
[
str
,
str
]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries
=
mmcif_loop_to_list
(
prefix
,
parsed_info
)
return
{
entry
[
index
]:
entry
for
entry
in
entries
}
def
parse
(
*
,
file_id
:
str
,
mmcif_string
:
str
,
catch_all_errors
:
bool
=
True
)
->
ParsingResult
:
"""Entry point, parses an mmcif_string.
Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.
Returns:
A ParsingResult.
"""
errors
=
{}
try
:
parser
=
PDB
.
MMCIFParser
(
QUIET
=
True
)
handle
=
io
.
StringIO
(
mmcif_string
)
full_structure
=
parser
.
get_structure
(
''
,
handle
)
first_model_structure
=
_get_first_model
(
full_structure
)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info
=
parser
.
_mmcif_dict
# pylint:disable=protected-access
# Ensure all values are lists, even if singletons.
for
key
,
value
in
parsed_info
.
items
():
if
not
isinstance
(
value
,
list
):
parsed_info
[
key
]
=
[
value
]
header
=
_get_header
(
parsed_info
)
# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains
=
_get_protein_chains
(
parsed_info
=
parsed_info
)
if
not
valid_chains
:
return
ParsingResult
(
None
,
{(
file_id
,
''
):
'No protein chains found in this file.'
})
seq_start_num
=
{
chain_id
:
min
([
monomer
.
num
for
monomer
in
seq
])
for
chain_id
,
seq
in
valid_chains
.
items
()}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id
=
{}
seq_to_structure_mappings
=
{}
for
atom
in
_get_atom_site_list
(
parsed_info
):
if
atom
.
model_num
!=
'1'
:
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id
[
atom
.
mmcif_chain_id
]
=
atom
.
author_chain_id
if
atom
.
mmcif_chain_id
in
valid_chains
:
hetflag
=
' '
if
atom
.
hetatm_atom
==
'HETATM'
:
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if
atom
.
residue_name
in
(
'HOH'
,
'WAT'
):
hetflag
=
'W'
else
:
hetflag
=
'H_'
+
atom
.
residue_name
insertion_code
=
atom
.
insertion_code
if
not
_is_set
(
atom
.
insertion_code
):
insertion_code
=
' '
position
=
ResiduePosition
(
chain_id
=
atom
.
author_chain_id
,
residue_number
=
int
(
atom
.
author_seq_num
),
insertion_code
=
insertion_code
)
seq_idx
=
int
(
atom
.
mmcif_seq_num
)
-
seq_start_num
[
atom
.
mmcif_chain_id
]
current
=
seq_to_structure_mappings
.
get
(
atom
.
author_chain_id
,
{})
current
[
seq_idx
]
=
ResidueAtPosition
(
position
=
position
,
name
=
atom
.
residue_name
,
is_missing
=
False
,
hetflag
=
hetflag
)
seq_to_structure_mappings
[
atom
.
author_chain_id
]
=
current
# Add missing residue information to seq_to_structure_mappings.
for
chain_id
,
seq_info
in
valid_chains
.
items
():
author_chain
=
mmcif_to_author_chain_id
[
chain_id
]
current_mapping
=
seq_to_structure_mappings
[
author_chain
]
for
idx
,
monomer
in
enumerate
(
seq_info
):
if
idx
not
in
current_mapping
:
current_mapping
[
idx
]
=
ResidueAtPosition
(
position
=
None
,
name
=
monomer
.
id
,
is_missing
=
True
,
hetflag
=
' '
)
author_chain_to_sequence
=
{}
for
chain_id
,
seq_info
in
valid_chains
.
items
():
author_chain
=
mmcif_to_author_chain_id
[
chain_id
]
seq
=
[]
for
monomer
in
seq_info
:
code
=
SCOPData
.
protein_letters_3to1
.
get
(
monomer
.
id
,
'X'
)
seq
.
append
(
code
if
len
(
code
)
==
1
else
'X'
)
seq
=
''
.
join
(
seq
)
author_chain_to_sequence
[
author_chain
]
=
seq
mmcif_object
=
MmcifObject
(
file_id
=
file_id
,
header
=
header
,
structure
=
first_model_structure
,
chain_to_seqres
=
author_chain_to_sequence
,
seqres_to_structure
=
seq_to_structure_mappings
,
raw_string
=
parsed_info
)
return
ParsingResult
(
mmcif_object
=
mmcif_object
,
errors
=
errors
)
except
Exception
as
e
:
# pylint:disable=broad-except
errors
[(
file_id
,
''
)]
=
e
if
not
catch_all_errors
:
raise
return
ParsingResult
(
mmcif_object
=
None
,
errors
=
errors
)
def
_get_first_model
(
structure
:
PdbStructure
)
->
PdbStructure
:
"""Returns the first model in a Biopython structure."""
return
next
(
structure
.
get_models
())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE
=
21
def
get_release_date
(
parsed_info
:
MmCIFDict
)
->
str
:
"""Returns the oldest revision date."""
revision_dates
=
parsed_info
[
'_pdbx_audit_revision_history.revision_date'
]
return
min
(
revision_dates
)
def
_get_header
(
parsed_info
:
MmCIFDict
)
->
PdbHeader
:
"""Returns a basic header containing method, release date and resolution."""
header
=
{}
experiments
=
mmcif_loop_to_list
(
'_exptl.'
,
parsed_info
)
header
[
'structure_method'
]
=
','
.
join
([
experiment
[
'_exptl.method'
].
lower
()
for
experiment
in
experiments
])
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if
'_pdbx_audit_revision_history.revision_date'
in
parsed_info
:
header
[
'release_date'
]
=
get_release_date
(
parsed_info
)
else
:
logging
.
warning
(
'Could not determine release_date: %s'
,
parsed_info
[
'_entry.id'
])
header
[
'resolution'
]
=
0.00
for
res_key
in
(
'_refine.ls_d_res_high'
,
'_em_3d_reconstruction.resolution'
,
'_reflns.d_resolution_high'
):
if
res_key
in
parsed_info
:
try
:
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
'resolution'
]
=
float
(
raw_resolution
)
except
ValueError
:
logging
.
warning
(
'Invalid resolution format: %s'
,
parsed_info
[
res_key
])
return
header
def
_get_atom_site_list
(
parsed_info
:
MmCIFDict
)
->
Sequence
[
AtomSite
]:
"""Returns list of atom sites; contains data not present in the structure."""
return
[
AtomSite
(
*
site
)
for
site
in
zip
(
# pylint:disable=g-complex-comprehension
parsed_info
[
'_atom_site.label_comp_id'
],
parsed_info
[
'_atom_site.auth_asym_id'
],
parsed_info
[
'_atom_site.label_asym_id'
],
parsed_info
[
'_atom_site.auth_seq_id'
],
parsed_info
[
'_atom_site.label_seq_id'
],
parsed_info
[
'_atom_site.pdbx_PDB_ins_code'
],
parsed_info
[
'_atom_site.group_PDB'
],
parsed_info
[
'_atom_site.pdbx_PDB_model_num'
],
)]
def
_get_protein_chains
(
*
,
parsed_info
:
Mapping
[
str
,
Any
])
->
Mapping
[
ChainId
,
Sequence
[
Monomer
]]:
"""Extracts polymer information for protein chains only.
Args:
parsed_info: _mmcif_dict produced by the Biopython parser.
Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs
=
mmcif_loop_to_list
(
'_entity_poly_seq.'
,
parsed_info
)
polymers
=
collections
.
defaultdict
(
list
)
for
entity_poly_seq
in
entity_poly_seqs
:
polymers
[
entity_poly_seq
[
'_entity_poly_seq.entity_id'
]].
append
(
Monomer
(
id
=
entity_poly_seq
[
'_entity_poly_seq.mon_id'
],
num
=
int
(
entity_poly_seq
[
'_entity_poly_seq.num'
])))
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps
=
mmcif_loop_to_dict
(
'_chem_comp.'
,
'_chem_comp.id'
,
parsed_info
)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms
=
mmcif_loop_to_list
(
'_struct_asym.'
,
parsed_info
)
entity_to_mmcif_chains
=
collections
.
defaultdict
(
list
)
for
struct_asym
in
struct_asyms
:
chain_id
=
struct_asym
[
'_struct_asym.id'
]
entity_id
=
struct_asym
[
'_struct_asym.entity_id'
]
entity_to_mmcif_chains
[
entity_id
].
append
(
chain_id
)
# Identify and return the valid protein chains.
valid_chains
=
{}
for
entity_id
,
seq_info
in
polymers
.
items
():
chain_ids
=
entity_to_mmcif_chains
[
entity_id
]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if
any
([
'peptide'
in
chem_comps
[
monomer
.
id
][
'_chem_comp.type'
]
for
monomer
in
seq_info
]):
for
chain_id
in
chain_ids
:
valid_chains
[
chain_id
]
=
seq_info
return
valid_chains
def
_is_set
(
data
:
str
)
->
bool
:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return
data
not
in
(
'.'
,
'?'
)
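A usage sketch of the parser; the input path is a placeholder (any mmCIF file downloaded from the PDB would do):

```python
# Hypothetical local file; '1abc' stands in for a real PDB id.
with open('1abc.cif') as f:
    mmcif_string = f.read()

result = parse(file_id='1abc', mmcif_string=mmcif_string)
if result.mmcif_object is None:
    # Parsing failed; errors maps (file_id, chain_id) to the exception.
    print(result.errors)
else:
    print(result.mmcif_object.header['resolution'])
    print(result.mmcif_object.chain_to_seqres.get('A', ''))
```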
openfold/features/np/__init__.py (new file, empty)

openfold/features/np/data_pipeline.py (new file)
```python
import os
import numpy as np
from typing import Mapping, Optional, Sequence

from openfold.features import templates, parsers
from openfold.features.np import jackhmmer, hhblits, hhsearch
from openfold.np import residue_constants

FeatureDict = Mapping[str, np.ndarray]


def make_sequence_features(sequence: str, description: str,
                           num_res: int) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    features = {}
    features['aatype'] = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True)
    features['between_segment_residues'] = np.zeros((num_res,),
                                                    dtype=np.int32)
    features['domain_name'] = np.array([description.encode('utf-8')],
                                       dtype=np.object_)
    features['residue_index'] = np.array(range(num_res), dtype=np.int32)
    features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
    features['sequence'] = np.array([sequence.encode('utf-8')],
                                    dtype=np.object_)
    return features


def make_msa_features(
        msas: Sequence[Sequence[str]],
        deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
    """Constructs a feature dict of MSA features."""
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    int_msa = []
    deletion_matrix = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f'MSA {msa_index} must contain at least one sequence.')
        for sequence_index, sequence in enumerate(msa):
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res]
                 for res in sequence])
            deletion_matrix.append(
                deletion_matrices[msa_index][sequence_index])

    num_res = len(msas[0][0])
    num_alignments = len(int_msa)
    features = {}
    features['deletion_matrix_int'] = np.array(deletion_matrix,
                                               dtype=np.int32)
    features['msa'] = np.array(int_msa, dtype=np.int32)
    features['num_alignments'] = np.array([num_alignments] * num_res,
                                          dtype=np.int32)
    return features


class DataPipeline:
    """Runs the alignment tools and assembles the input features."""

    def __init__(self,
                 jackhmmer_binary_path: str,
                 hhblits_binary_path: str,
                 hhsearch_binary_path: str,
                 uniref90_database_path: str,
                 mgnify_database_path: str,
                 bfd_database_path: Optional[str],
                 uniclust30_database_path: Optional[str],
                 small_bfd_database_path: Optional[str],
                 pdb70_database_path: str,
                 template_featurizer: templates.TemplateHitFeaturizer,
                 use_small_bfd: bool,
                 mgnify_max_hits: int = 501,
                 uniref_max_hits: int = 10000):
        """Constructs a feature dict for a given FASTA file."""
        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path)
        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path)
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path])
        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path)
        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path,
            databases=[pdb70_database_path])
        self.template_featurizer = template_featurizer
        self.mgnify_max_hits = mgnify_max_hits
        self.uniref_max_hits = uniref_max_hits

    def process(self, input_fasta_path: str,
                msa_output_dir: str) -> FeatureDict:
        """Runs alignment tools on the input sequence and creates features."""
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f'More than one input sequence found in '
                f'{input_fasta_path}.')
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
            input_fasta_path)[0]
        jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
            input_fasta_path)[0]

        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_uniref90_result['sto'],
            max_sequences=self.uniref_max_hits)
        hhsearch_result = self.hhsearch_pdb70_runner.query(
            uniref90_msa_as_a3m)

        uniref90_out_path = os.path.join(msa_output_dir,
                                         'uniref90_hits.sto')
        with open(uniref90_out_path, 'w') as f:
            f.write(jackhmmer_uniref90_result['sto'])

        # '.sto' extension to match the Stockholm contents being written.
        mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
        with open(mgnify_out_path, 'w') as f:
            f.write(jackhmmer_mgnify_result['sto'])

        pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
        with open(pdb70_out_path, 'w') as f:
            f.write(hhsearch_result)

        uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
            jackhmmer_uniref90_result['sto'])
        mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
            jackhmmer_mgnify_result['sto'])
        hhsearch_hits = parsers.parse_hhr(hhsearch_result)
        mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
        mgnify_deletion_matrix = mgnify_deletion_matrix[
            :self.mgnify_max_hits]

        if self._use_small_bfd:
            jackhmmer_small_bfd_result = (
                self.jackhmmer_small_bfd_runner.query(input_fasta_path)[0])

            bfd_out_path = os.path.join(msa_output_dir,
                                        'small_bfd_hits.a3m')
            with open(bfd_out_path, 'w') as f:
                f.write(jackhmmer_small_bfd_result['sto'])

            bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
                jackhmmer_small_bfd_result['sto'])
        else:
            hhblits_bfd_uniclust_result = (
                self.hhblits_bfd_uniclust_runner.query(input_fasta_path))

            bfd_out_path = os.path.join(msa_output_dir,
                                        'bfd_uniclust_hits.a3m')
            with open(bfd_out_path, 'w') as f:
                f.write(hhblits_bfd_uniclust_result['a3m'])

            bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
                hhblits_bfd_uniclust_result['a3m'])

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=hhsearch_hits)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res)

        msa_features = make_msa_features(
            msas=(uniref90_msa, bfd_msa, mgnify_msa),
            deletion_matrices=(uniref90_deletion_matrix,
                               bfd_deletion_matrix,
                               mgnify_deletion_matrix))

        return {**sequence_features, **msa_features,
                **templates_result.features}
```
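Constructing the pipeline requires local copies of the alignment binaries and sequence databases; every path and the featurizer below are placeholders for a real installation:

```python
# All paths are placeholders; template_featurizer is a
# templates.TemplateHitFeaturizer constructed elsewhere.
pipeline = DataPipeline(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    hhblits_binary_path='/usr/bin/hhblits',
    hhsearch_binary_path='/usr/bin/hhsearch',
    uniref90_database_path='/data/uniref90/uniref90.fasta',
    mgnify_database_path='/data/mgnify/mgy_clusters.fa',
    bfd_database_path=None,
    uniclust30_database_path=None,
    small_bfd_database_path='/data/small_bfd/bfd_subset.fasta',
    pdb70_database_path='/data/pdb70/pdb70',
    template_featurizer=template_featurizer,
    use_small_bfd=True,
)
features = pipeline.process('query.fasta', msa_output_dir='msas/')
```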
openfold/features/np/hhblits.py (new file)
"""Library to run HHblits from Python."""
import
glob
import
os
import
subprocess
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
from
absl
import
logging
from
openfold.features.np
import
utils
_HHBLITS_DEFAULT_P
=
20
_HHBLITS_DEFAULT_Z
=
500
class
HHBlits
:
"""Python wrapper of the HHblits binary."""
def
__init__
(
self
,
*
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
n_cpu
:
int
=
4
,
n_iter
:
int
=
3
,
e_value
:
float
=
0.001
,
maxseq
:
int
=
1_000_000
,
realign_max
:
int
=
100_000
,
maxfilt
:
int
=
100_000
,
min_prefilter_hits
:
int
=
1000
,
all_seqs
:
bool
=
False
,
alt
:
Optional
[
int
]
=
None
,
p
:
int
=
_HHBLITS_DEFAULT_P
,
z
:
int
=
_HHBLITS_DEFAULT_Z
):
"""Initializes the Python HHblits wrapper.
Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self
.
binary_path
=
binary_path
self
.
databases
=
databases
for
database_path
in
self
.
databases
:
if
not
glob
.
glob
(
database_path
+
'_*'
):
logging
.
error
(
'Could not find HHBlits database %s'
,
database_path
)
raise
ValueError
(
f
'Could not find HHBlits database
{
database_path
}
'
)
self
.
n_cpu
=
n_cpu
self
.
n_iter
=
n_iter
self
.
e_value
=
e_value
self
.
maxseq
=
maxseq
self
.
realign_max
=
realign_max
self
.
maxfilt
=
maxfilt
self
.
min_prefilter_hits
=
min_prefilter_hits
self
.
all_seqs
=
all_seqs
self
.
alt
=
alt
self
.
p
=
p
self
.
z
=
z
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
'/tmp'
)
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.a3m'
)
db_cmd
=
[]
for
db_path
in
self
.
databases
:
db_cmd
.
append
(
'-d'
)
db_cmd
.
append
(
db_path
)
cmd
=
[
self
.
binary_path
,
'-i'
,
input_fasta_path
,
'-cpu'
,
str
(
self
.
n_cpu
),
'-oa3m'
,
a3m_path
,
'-o'
,
'/dev/null'
,
'-n'
,
str
(
self
.
n_iter
),
'-e'
,
str
(
self
.
e_value
),
'-maxseq'
,
str
(
self
.
maxseq
),
'-realign_max'
,
str
(
self
.
realign_max
),
'-maxfilt'
,
str
(
self
.
maxfilt
),
'-min_prefilter_hits'
,
str
(
self
.
min_prefilter_hits
)]
if
self
.
all_seqs
:
cmd
+=
[
'-all'
]
if
self
.
alt
:
cmd
+=
[
'-alt'
,
str
(
self
.
alt
)]
if
self
.
p
!=
_HHBLITS_DEFAULT_P
:
cmd
+=
[
'-p'
,
str
(
self
.
p
)]
if
self
.
z
!=
_HHBLITS_DEFAULT_Z
:
cmd
+=
[
'-Z'
,
str
(
self
.
z
)]
cmd
+=
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'HHblits query'
):
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
if
retcode
:
# Logs have a 15k character limit, so log HHblits error line by line.
logging
.
error
(
'HHblits failed. HHblits stderr begin:'
)
for
error_line
in
stderr
.
decode
(
'utf-8'
).
splitlines
():
if
error_line
.
strip
():
logging
.
error
(
error_line
.
strip
())
logging
.
error
(
'HHblits stderr end'
)
raise
RuntimeError
(
'HHblits failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
stdout
.
decode
(
'utf-8'
),
stderr
[:
500_000
].
decode
(
'utf-8'
)))
with
open
(
a3m_path
)
as
f
:
a3m
=
f
.
read
()
raw_output
=
dict
(
a3m
=
a3m
,
output
=
stdout
,
stderr
=
stderr
,
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
)
return
raw_output
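Usage mirrors the other wrappers; per the docstring, the database argument is the common prefix of the HH-suite files, not a single file. The paths here are placeholders:

```python
# Placeholder paths for a local HH-suite installation.
runner = HHBlits(
    binary_path='/usr/bin/hhblits',
    databases=['/data/uniclust30/uniclust30_2018_08'],
    n_cpu=8)
result = runner.query('query.fasta')
print(result['a3m'][:200])  # the resulting MSA in a3m format
```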
openfold/features/np/hhsearch.py (new file)
"""Library to run HHsearch from Python."""
import
glob
import
os
import
subprocess
from
typing
import
Sequence
from
absl
import
logging
from
openfold.features.np
import
utils
class
HHSearch
:
"""Python wrapper of the HHsearch binary."""
def
__init__
(
self
,
*
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
maxseq
:
int
=
1_000_000
):
"""Initializes the Python HHsearch wrapper.
Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self
.
binary_path
=
binary_path
self
.
databases
=
databases
self
.
maxseq
=
maxseq
for
database_path
in
self
.
databases
:
if
not
glob
.
glob
(
database_path
+
'_*'
):
logging
.
error
(
'Could not find HHsearch database %s'
,
database_path
)
raise
ValueError
(
f
'Could not find HHsearch database
{
database_path
}
'
)
def
query
(
self
,
a3m
:
str
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
'/tmp'
)
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
'query.a3m'
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.hhr'
)
with
open
(
input_path
,
'w'
)
as
f
:
f
.
write
(
a3m
)
db_cmd
=
[]
for
db_path
in
self
.
databases
:
db_cmd
.
append
(
'-d'
)
db_cmd
.
append
(
db_path
)
cmd
=
[
self
.
binary_path
,
'-i'
,
input_path
,
'-o'
,
hhr_path
,
'-maxseq'
,
str
(
self
.
maxseq
)
]
+
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'HHsearch query'
):
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
if
retcode
:
# Stderr is truncated to prevent proto size errors in Beam.
raise
RuntimeError
(
'HHSearch failed:
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
stdout
.
decode
(
'utf-8'
),
stderr
[:
100_000
].
decode
(
'utf-8'
)))
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
return
hhr
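Unlike HHblits, `HHSearch.query` takes the a3m contents as a string rather than a FASTA path, which is why `DataPipeline` first converts the uniref90 Stockholm alignment to a3m. A placeholder usage sketch:

```python
# Placeholder paths; the pdb70 prefix is as distributed with HH-suite.
a3m_string = open('query_msa.a3m').read()  # hypothetical precomputed MSA
searcher = HHSearch(binary_path='/usr/bin/hhsearch',
                    databases=['/data/pdb70/pdb70'])
hhr = searcher.query(a3m_string)  # raw .hhr text, parsed by parsers.parse_hhr
```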
openfold/features/np/jackhmmer.py (new file)
"""Library to run Jackhmmer from Python."""
from
concurrent
import
futures
import
glob
import
os
import
subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
absl
import
logging
from
openfold.features.np
import
utils
class
Jackhmmer
:
"""Python wrapper of the Jackhmmer binary."""
def
__init__
(
self
,
*
,
binary_path
:
str
,
database_path
:
str
,
n_cpu
:
int
=
8
,
n_iter
:
int
=
1
,
e_value
:
float
=
0.0001
,
z_value
:
Optional
[
int
]
=
None
,
get_tblout
:
bool
=
False
,
filter_f1
:
float
=
0.0005
,
filter_f2
:
float
=
0.00005
,
filter_f3
:
float
=
0.0000005
,
incdom_e
:
Optional
[
float
]
=
None
,
dom_e
:
Optional
[
float
]
=
None
,
num_streamed_chunks
:
Optional
[
int
]
=
None
,
streaming_callback
:
Optional
[
Callable
[[
int
],
None
]]
=
None
):
"""Initializes the Python Jackhmmer wrapper.
Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value, see Jackhmmer docs for more details.
get_tblout: Whether to save tblout string.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
round.
dom_e: Domain e-value criteria for inclusion in tblout.
num_streamed_chunks: Number of database chunks to stream over.
streaming_callback: Callback function run after each chunk iteration with
the iteration number as argument.
"""
self
.
binary_path
=
binary_path
self
.
database_path
=
database_path
self
.
num_streamed_chunks
=
num_streamed_chunks
if
not
os
.
path
.
exists
(
self
.
database_path
)
and
num_streamed_chunks
is
None
:
logging
.
error
(
'Could not find Jackhmmer database %s'
,
database_path
)
raise
ValueError
(
f
'Could not find Jackhmmer database
{
database_path
}
'
)
self
.
n_cpu
=
n_cpu
self
.
n_iter
=
n_iter
self
.
e_value
=
e_value
self
.
z_value
=
z_value
self
.
filter_f1
=
filter_f1
self
.
filter_f2
=
filter_f2
self
.
filter_f3
=
filter_f3
self
.
incdom_e
=
incdom_e
self
.
dom_e
=
dom_e
self
.
get_tblout
=
get_tblout
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
'/tmp'
)
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.sto'
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
# speeds up the pipeline at the expensive of sensitivity. They are
# currently set very low to make querying Mgnify run in a reasonable
# amount of time.
cmd_flags
=
[
# Don't pollute stdout with Jackhmmer output.
'-o'
,
'/dev/null'
,
'-A'
,
sto_path
,
'--noali'
,
'--F1'
,
str
(
self
.
filter_f1
),
'--F2'
,
str
(
self
.
filter_f2
),
'--F3'
,
str
(
self
.
filter_f3
),
'--incE'
,
str
(
self
.
e_value
),
# Report only sequences with E-values <= x in per-sequence output.
'-E'
,
str
(
self
.
e_value
),
'--cpu'
,
str
(
self
.
n_cpu
),
'-N'
,
str
(
self
.
n_iter
)
]
if
self
.
get_tblout
:
tblout_path
=
os
.
path
.
join
(
query_tmp_dir
,
'tblout.txt'
)
cmd_flags
.
extend
([
'--tblout'
,
tblout_path
])
if
self
.
z_value
:
cmd_flags
.
extend
([
'-Z'
,
str
(
self
.
z_value
)])
if
self
.
dom_e
is
not
None
:
cmd_flags
.
extend
([
'--domE'
,
str
(
self
.
dom_e
)])
if
self
.
incdom_e
is
not
None
:
cmd_flags
.
extend
([
'--incdomE'
,
str
(
self
.
incdom_e
)])
cmd
=
[
self
.
binary_path
]
+
cmd_flags
+
[
input_fasta_path
,
database_path
]
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
f
'Jackhmmer (
{
os
.
path
.
basename
(
database_path
)
}
) query'
):
_
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
if
retcode
:
raise
RuntimeError
(
'Jackhmmer failed
\n
stderr:
\n
%s
\n
'
%
stderr
.
decode
(
'utf-8'
))
# Get e-values for each target name
tbl
=
''
if
self
.
get_tblout
:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
raw_output
=
dict
(
sto
=
sto
,
tbl
=
tbl
,
stderr
=
stderr
,
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
)
return
raw_output
def
query
(
self
,
input_fasta_path
:
str
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
"""Queries the database using Jackhmmer."""
if
self
.
num_streamed_chunks
is
None
:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
'
{
self
.
database_path
}
.
{
db_idx
}
'
db_local_chunk
=
lambda
db_idx
:
f
'/tmp/ramdisk/
{
db_basename
}
.
{
db_idx
}
'
# Remove existing files to prevent OOM
for
f
in
glob
.
glob
(
db_local_chunk
(
'[0-9]*'
)):
try
:
os
.
remove
(
f
)
except
OSError
:
print
(
f
'OSError while deleting
{
f
}
'
)
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with
futures
.
ThreadPoolExecutor
(
max_workers
=
2
)
as
executor
:
chunked_output
=
[]
for
i
in
range
(
1
,
self
.
num_streamed_chunks
+
1
):
# Copy the chunk locally
if
i
==
1
:
future
=
executor
.
submit
(
request
.
urlretrieve
,
db_remote_chunk
(
i
),
db_local_chunk
(
i
))
if
i
<
self
.
num_streamed_chunks
:
next_future
=
executor
.
submit
(
request
.
urlretrieve
,
db_remote_chunk
(
i
+
1
),
db_local_chunk
(
i
+
1
))
# Run Jackhmmer with the chunk
future
.
result
()
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
)))
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
future
=
next_future
if
self
.
streaming_callback
:
self
.
streaming_callback
(
i
)
return
chunked_output
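For illustration only (not part of the committed file): with num_streamed_chunks unset, query runs a single pass over the whole database and returns a one-element list. Paths below are placeholders.

    # Hypothetical usage of the Jackhmmer wrapper; paths are placeholders.
    from openfold.features.np.jackhmmer import Jackhmmer

    jackhmmer_runner = Jackhmmer(
        binary_path='/usr/bin/jackhmmer',
        database_path='/data/uniref90/uniref90.fasta',
        get_tblout=True)
    [chunk_result] = jackhmmer_runner.query('/tmp/query.fasta')
    print(chunk_result['sto'][:200])  # Stockholm-format alignment text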
openfold/features/np/kalign.py
0 → 100644
"""A Python wrapper for Kalign."""
import
os
import
subprocess
from
typing
import
Sequence
from
absl
import
logging
from
openfold.features.np
import
utils
def
_to_a3m
(
sequences
:
Sequence
[
str
])
->
str
:
"""Converts sequences to an a3m file."""
names
=
[
'sequence %d'
%
i
for
i
in
range
(
1
,
len
(
sequences
)
+
1
)]
a3m
=
[]
for
sequence
,
name
in
zip
(
sequences
,
names
):
a3m
.
append
(
u
'>'
+
name
+
u
'
\n
'
)
a3m
.
append
(
sequence
+
u
'
\n
'
)
return
''
.
join
(
a3m
)
class
Kalign
:
"""Python wrapper of the Kalign binary."""
def
__init__
(
self
,
*
,
binary_path
:
str
):
"""Initializes the Python Kalign wrapper.
Args:
binary_path: The path to the Kalign binary.
Raises:
RuntimeError: If Kalign binary not found within the path.
"""
self
.
binary_path
=
binary_path
def
align
(
self
,
sequences
:
Sequence
[
str
])
->
str
:
"""Aligns the sequences and returns the alignment in A3M string.
Args:
sequences: A list of query sequence strings. The sequences have to be at
least 6 residues long (Kalign requires this). Note that the order in
which you give the sequences might alter the output slightly as
different alignment tree might get constructed.
Returns:
A string with the alignment in a3m format.
Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging
.
info
(
'Aligning %d sequences'
,
len
(
sequences
))
for
s
in
sequences
:
if
len
(
s
)
<
6
:
raise
ValueError
(
'Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).'
%
(
s
,
len
(
s
)))
with
utils
.
tmpdir_manager
(
base_dir
=
'/tmp'
)
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
'input.fasta'
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.a3m'
)
with
open
(
input_fasta_path
,
'w'
)
as
f
:
f
.
write
(
_to_a3m
(
sequences
))
cmd
=
[
self
.
binary_path
,
'-i'
,
input_fasta_path
,
'-o'
,
output_a3m_path
,
'-format'
,
'fasta'
,
]
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'Kalign query'
):
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
logging
.
info
(
'Kalign stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
,
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
))
if
retcode
:
raise
RuntimeError
(
'Kalign failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
)))
with
open
(
output_a3m_path
)
as
f
:
a3m
=
f
.
read
()
return
a3m
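For illustration only (not part of the committed file): aligning two short sequences, each at least 6 residues as required above. The binary path is a placeholder.

    # Hypothetical usage of the Kalign wrapper; the binary path is a placeholder.
    from openfold.features.np.kalign import Kalign

    kalign_runner = Kalign(binary_path='/usr/bin/kalign')
    a3m = kalign_runner.align(['MKTAYIAKQR', 'MKTAYIARQR'])  # a3m-format string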
openfold/features/np/utils.py
0 → 100644
"""Common utilities for data pipeline tools."""
import
contextlib
import
shutil
import
tempfile
import
time
from
typing
import
Optional
from
absl
import
logging
@
contextlib
.
contextmanager
def
tmpdir_manager
(
base_dir
:
Optional
[
str
]
=
None
):
"""Context manager that deletes a temporary directory on exit."""
tmpdir
=
tempfile
.
mkdtemp
(
dir
=
base_dir
)
try
:
yield
tmpdir
finally
:
shutil
.
rmtree
(
tmpdir
,
ignore_errors
=
True
)
@
contextlib
.
contextmanager
def
timing
(
msg
:
str
):
logging
.
info
(
'Started %s'
,
msg
)
tic
=
time
.
time
()
yield
toc
=
time
.
time
()
logging
.
info
(
'Finished %s in %.3f seconds'
,
msg
,
toc
-
tic
)
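For illustration only (not part of the committed file): the two context managers combined, as the tool wrappers above use them.

    # Scratch files created inside tmp_dir are deleted when the block exits;
    # the timing block logs start/finish with elapsed seconds.
    import os
    from openfold.features.np import utils

    with utils.tmpdir_manager(base_dir='/tmp') as tmp_dir:
        with utils.timing('scratch work'):
            scratch_file = os.path.join(tmp_dir, 'scratch.txt')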
openfold/features/parsers.py
0 → 100644
"""Functions for parsing various file formats."""
import
collections
import
dataclasses
import
re
import
string
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
DeletionMatrix
=
Sequence
[
Sequence
[
int
]]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
TemplateHit
:
"""Class representing a template hit."""
index
:
int
name
:
str
aligned_cols
:
int
sum_probs
:
float
query
:
str
hit_sequence
:
str
indices_query
:
List
[
int
]
indices_hit
:
List
[
int
]
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
        fasta_string: The string contents of a FASTA file.

    Returns:
        A tuple of two lists:
        * A list of sequences.
        * A list of sequence descriptions taken from the comment lines. In the
            same order as the sequences.
    """
    sequences = []
    descriptions = []
    index = -1
    for line in fasta_string.splitlines():
        line = line.strip()
        if line.startswith('>'):
            index += 1
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append('')
            continue
        elif not line:
            continue  # Skip blank lines.
        sequences[index] += line

    return sequences, descriptions
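For illustration only (not part of the committed file): parse_fasta on a two-record FASTA string; sequence lines are concatenated per record.

    seqs, descs = parse_fasta('>seq1 first\nMKTAYI\nAKQR\n>seq2 second\nMKTAYIARQR\n')
    assert seqs == ['MKTAYIAKQR', 'MKTAYIARQR']
    assert descs == ['seq1 first', 'seq2 second']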
def parse_stockholm(
        stockholm_string: str
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
    """Parses sequences and deletion matrix from stockholm format alignment.

    Args:
        stockholm_string: The string contents of a stockholm file. The first
            sequence in the file should be the query sequence.

    Returns:
        A tuple of:
        * A list of sequences that have been aligned to the query. These
            might contain duplicates.
        * The deletion matrix for the alignment as a list of lists. The element
            at `deletion_matrix[i][j]` is the number of residues deleted from
            the aligned sequence i at residue position j.
        * The names of the targets matched, including the jackhmmer subsequence
            suffix.
    """
    name_to_sequence = collections.OrderedDict()
    for line in stockholm_string.splitlines():
        line = line.strip()
        if not line or line.startswith(('#', '//')):
            continue
        name, sequence = line.split()
        if name not in name_to_sequence:
            name_to_sequence[name] = ''
        name_to_sequence[name] += sequence

    msa = []
    deletion_matrix = []

    query = ''
    keep_columns = []
    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # Gather the columns with gaps from the query
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != '-']

        # Remove the columns with gaps in the query from all sequences.
        aligned_sequence = ''.join([sequence[c] for c in keep_columns])

        msa.append(aligned_sequence)

        # Count the number of deletions w.r.t. query.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res != '-' or query_res != '-':
                if query_res == '-':
                    deletion_count += 1
                else:
                    deletion_vec.append(deletion_count)
                    deletion_count = 0
        deletion_matrix.append(deletion_vec)

    return msa, deletion_matrix, list(name_to_sequence.keys())
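For illustration only (not part of the committed file): a two-sequence Stockholm alignment in which the target carries one inserted residue relative to the query (a gap in the query row). The gap column is dropped from the MSA and recorded in the deletion matrix instead.

    sto = (
        'query   MK-TA\n'
        'target  MKLTA\n'
        '//\n'
    )
    msa, deletions, names = parse_stockholm(sto)
    assert msa == ['MKTA', 'MKTA']                    # query-gap column removed
    assert deletions == [[0, 0, 0, 0], [0, 0, 1, 0]]  # one insertion before 'T'
    assert names == ['query', 'target']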
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
    """Parses sequences and deletion matrix from a3m format alignment.

    Args:
        a3m_string: The string contents of a a3m file. The first sequence in the
            file should be the query sequence.

    Returns:
        A tuple of:
        * A list of sequences that have been aligned to the query. These
            might contain duplicates.
        * The deletion matrix for the alignment as a list of lists. The element
            at `deletion_matrix[i][j]` is the number of residues deleted from
            the aligned sequence i at residue position j.
    """
    sequences, _ = parse_fasta(a3m_string)
    deletion_matrix = []
    for msa_sequence in sequences:
        deletion_vec = []
        deletion_count = 0
        for j in msa_sequence:
            if j.islower():
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    # Make the MSA matrix out of aligned (deletion-free) sequences.
    deletion_table = str.maketrans('', '', string.ascii_lowercase)
    aligned_sequences = [s.translate(deletion_table) for s in sequences]
    return aligned_sequences, deletion_matrix
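For illustration only (not part of the committed file): in a3m, lowercase letters mark insertions relative to the query; parse_a3m strips them and records their counts.

    msa, deletions = parse_a3m('>query\nMKTA\n>target\nMKlTA\n')
    assert msa == ['MKTA', 'MKTA']
    assert deletions == [[0, 0, 0, 0], [0, 0, 1, 0]]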
def _convert_sto_seq_to_a3m(
        query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
    for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
        if is_query_res_non_gap:
            yield sequence_res
        elif sequence_res != '-':
            yield sequence_res.lower()


def convert_stockholm_to_a3m(stockholm_format: str,
                             max_sequences: Optional[int] = None) -> str:
    """Converts MSA in Stockholm format to the A3M format."""
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    for line in stockholm_format.splitlines():
        reached_max_sequences = max_sequences and len(sequences) >= max_sequences
        if line.strip() and not line.startswith(('#', '//')):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ''
            sequences[seqname] += aligned_seq

    for line in stockholm_format.splitlines():
        if line[:4] == '#=GS':
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ''
            if feature != 'DE':
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            if len(descriptions) == len(sequences):
                break

    # Convert sto format to a3m line by line
    a3m_sequences = {}
    # query_sequence is assumed to be the first sequence
    query_sequence = next(iter(sequences.values()))
    query_non_gaps = [res != '-' for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        a3m_sequences[seqname] = ''.join(
            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))

    fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
                    for k in a3m_sequences)
    return '\n'.join(fasta_chunks) + '\n'  # Include terminating newline.
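For illustration only (not part of the committed file): columns that are gaps in the query become lowercase insertions in the resulting a3m records, mirroring the deletion-matrix example above.

    a3m = convert_stockholm_to_a3m('query   MK-TA\ntarget  MKLTA\n//\n')
    # a3m == '>query \nMKTA\n>target \nMKlTA\n'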
def _get_hhr_line_regex_groups(
        regex_pattern: str, line: str) -> Sequence[Optional[str]]:
    match = re.match(regex_pattern, line)
    if match is None:
        raise RuntimeError(f'Could not parse query line {line}')
    return match.groups()


def _update_hhr_residue_indices_list(
        sequence: str, start_index: int, indices_list: List[int]):
    """Computes the relative indices for each residue with respect to the original sequence."""
    counter = start_index
    for symbol in sequence:
        if symbol == '-':
            indices_list.append(-1)
        else:
            indices_list.append(counter)
            counter += 1


def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses the detailed HMM HMM comparison section for a single Hit.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
        detailed_lines: A list of lines from a single comparison section between 2
            sequences (which each have their own HMMs)

    Returns:
        A dictionary with the information from that detailed comparison section

    Raises:
        RuntimeError: If a certain line cannot be processed
    """
    # Parse first 2 lines.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # Parse the summary line.
    pattern = (
        'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
        ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t'
        ']*Template_Neff=(.*)')
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            'Could not parse section: %s. Expected this:\n%s to contain summary.' %
            (detailed_lines, detailed_lines[2]))
    (prob_true, e_value, _, aligned_cols, _, _, sum_probs,
     neff) = [float(x) for x in match.groups()]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ''
    hit_sequence = ''
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line
        if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and
                not line.startswith('Q ss_pred') and
                not line.startswith('Q Consensus')):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can
            # parse everything after that.
            #              start        sequence       end       total_sequence_length
            patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Get the length of the parsed block using the start and finish indices,
            # and ensure it is the same as the actual block length.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            num_insertions = len([x for x in delta_query if x == '-'])
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith('T '):
            # Parse the hit sequence.
            if (not line.startswith('T ss_dssp') and
                    not line.startswith('T ss_pred') and
                    not line.startswith('T Consensus')):
                # Thus the first 17 characters must be 'T <hit_name> ', and we can
                # parse everything after that.
                #              start        sequence       end     total_sequence_length
                patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file."""
    lines = hhr_string.splitlines()

    # Each .hhr file starts with a results table, then has a sequence of hit
    # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
    # iterate through each paragraph to parse each hit.
    block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]

    hits = []
    if block_starts:
        block_starts.append(len(lines))  # Add the end of the final block.
        for i in range(len(block_starts) - 1):
            hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
    return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
    e_values = {'query': 0}
    lines = [line for line in tblout.splitlines() if line[0] != '#']
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name: and
    # (5) E-value (full sequence) (numbering from 1).
    for line in lines:
        fields = line.split()
        e_value = fields[4]
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values
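For illustration only (not part of the committed file): a fabricated, simplified tblout row. Real tblout rows have more columns; only fields 1 (target name) and 5 (full-sequence E-value) are read here.

    tblout = (
        '# comment line\n'
        'UniRef90_A0A0A0 - query - 1.2e-50 170.0 0.0 1.5e-50 169.6 0.0\n'
    )
    assert parse_e_values_from_tblout(tblout) == {'query': 0, 'UniRef90_A0A0A0': 1.2e-50}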
openfold/features/templates.py
0 → 100644
This diff is collapsed.
openfold/utils/affine_utils.py
...
@@ -198,7 +198,7 @@ class T:
         denom = torch.sqrt(sum((c * c for c in e0)) + eps)
         e0 = [c / denom for c in e0]
         dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
-        e1 = [c1 - c2 * dot for c1, c2 in zip(e1, e0)]
+        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
         denom = torch.sqrt(sum((c * c for c in e1)) + eps)
         e1 = [c / denom for c in e1]
         e2 = [
...
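The replaced line appears algebraically equivalent to its predecessor (both compute e1 - dot * e0 elementwise); the rewrite aligns the zip argument order with the dot computation above it. A standalone sanity check of the Gram-Schmidt step as written after this commit, not taken from the repository:

    import torch

    # After normalizing e0, subtracting the projection of e1 onto e0 and
    # renormalizing should leave e0 and e1 orthonormal.
    eps = 1e-8
    e0 = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(2.0)]
    e1 = [torch.tensor(1.0), torch.tensor(0.0), torch.tensor(0.0)]

    denom = torch.sqrt(sum((c * c for c in e0)) + eps)
    e0 = [c / denom for c in e0]
    dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
    e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]  # the line as rewritten
    denom = torch.sqrt(sum((c * c for c in e1)) + eps)
    e1 = [c / denom for c in e1]

    # Dot product of the two basis vectors should now be ~0.
    assert abs(sum((c1 * c2 for c1, c2 in zip(e0, e1)))) < 1e-5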
run_pretrained_alphafold.py
...
@@ -14,18 +14,22 @@
 # limitations under the License.
 import argparse
+import math
 import pickle
 import os
 # A hack to get OpenMM and PyTorch to peacefully coexist
+import random
+import sys
+from openfold.features import templates, feature_pipeline
+from openfold.features.np import data_pipeline
 os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
 import time
 import numpy as np
 import torch
+import torch.nn as nn
 from config import model_config
 from openfold.model.model import AlphaFold
...
@@ -35,12 +39,13 @@ from openfold.utils.import_weights import (
     import_jax_weights_,
 )
 from openfold.utils.tensor_utils import (
+    tree_map,
     tensor_tree_map,
 )
 FEAT_PATH = "tests/test_data/sample_feats.pickle"
+MAX_TEMPLATE_HITS = 20
 def main(args):
     config = model_config(args.model_name)
     model = AlphaFold(config.model)
...
@@ -48,9 +53,56 @@ def main(args):
     import_jax_weights_(model, args.param_path)
     model = model.to(args.device)
-    with open(FEAT_PATH, "rb") as f:
-        batch = pickle.load(f)
+    # FEATURE COLLECTION AND PROCESSING
+    use_small_bfd = args.preset == "reduced_dbs"
+    num_ensemble = 1
+    template_featurizer = templates.TemplateHitFeaturizer(
+        mmcif_dir=args.template_mmcif_dir,
+        max_template_date=args.max_template_date,
+        max_hits=MAX_TEMPLATE_HITS,
+        kalign_binary_path=args.kalign_binary_path,
+        release_dates_path=None,
+        obsolete_pdbs_path=args.obsolete_pdbs_path
+    )
+    data_processor = data_pipeline.DataPipeline(
+        jackhmmer_binary_path=args.jackhmmer_binary_path,
+        hhblits_binary_path=args.hhblits_binary_path,
+        hhsearch_binary_path=args.hhsearch_binary_path,
+        uniref90_database_path=args.uniref90_database_path,
+        mgnify_database_path=args.mgnify_database_path,
+        bfd_database_path=args.bfd_database_path,
+        uniclust30_database_path=args.uniclust30_database_path,
+        small_bfd_database_path=args.small_bfd_database_path,
+        pdb70_database_path=args.pdb70_database_path,
+        template_featurizer=template_featurizer,
+        use_small_bfd=use_small_bfd
+    )
+    output_dir_base = args.output_dir
+    random_seed = args.random_seed
+    if random_seed is None:
+        random_seed = random.randrange(sys.maxsize)
+    config.data.eval.num_ensemble = num_ensemble
+    feature_processor = feature_pipeline.FeaturePipeline(config)
+    if not os.path.exists(output_dir_base):
+        os.makedirs(output_dir_base)
+    msa_output_dir = os.path.join(output_dir_base, "msas")
+    if not os.path.exists(msa_output_dir):
+        os.makedirs(msa_output_dir)
+    print("Collecting data...")
+    feature_dict = data_processor.process(
+        input_fasta_path=args.fasta_path,
+        msa_output_dir=msa_output_dir
+    )
+    print("Generating features...")
+    processed_feature_dict = feature_processor.process_features(
+        feature_dict, random_seed
+    )
+    print("Executing model...")
+    batch = processed_feature_dict
     with torch.no_grad():
         batch = {
             k: torch.as_tensor(v, device=args.device)
...
@@ -117,6 +169,14 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--fasta_path", type=str, default=None, required=True)
+    parser.add_argument(
+        "--output_dir", type=str, default=os.getcwd(),
+        help="""Name of the directory in which to output the prediction""",
+        required=True
+    )
     parser.add_argument(
         "--device", type=str, default="cpu",
         help="""Name of the device on which to run the model. Any valid torch
...
@@ -127,16 +187,58 @@ if __name__ == "__main__":
         help="""Name of a model config. Choose one of model_{1-5} or
              model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
     )
-    parser.add_argument(
-        "--output_dir", type=str, default=os.getcwd(),
-        help="""Name of the directory in which to output the prediction"""
-    )
     parser.add_argument(
         "--param_path", type=str, default=None,
         help="""Path to model parameters. If None, parameters are selected
              automatically according to the model name from
              openfold/resources/params"""
     )
+    parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
+    parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
+    parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
+    parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
+    parser.add_argument('--uniref90_database_path', type=str, default=None, required=True)
+    parser.add_argument('--mgnify_database_path', type=str, default=None, required=True)
+    parser.add_argument('--bfd_database_path', type=str, default=None, required=True)
+    parser.add_argument('--small_bfd_database_path', type=str, default=None)
+    parser.add_argument('--uniclust30_database_path', type=str, default=None)
+    parser.add_argument('--pdb70_database_path', type=str, default=None, required=True)
+    parser.add_argument('--template_mmcif_dir', type=str, default=None, required=True)
+    parser.add_argument('--max_template_date', type=str, default=None, required=True)
+    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
+    parser.add_argument('--preset', type=str, default='full_dbs', required=True,
+                        choices=('reduced_dbs', 'full_dbs'))
+    parser.add_argument('--random_seed', type=str, default=None)
     args = parser.parse_args()
...
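For orientation (not part of the commit): one way the extended script might be invoked, sketched via subprocess so it fits a Python snippet. Every path is a placeholder, and --model_name is assumed to be defined in a part of the argument parser not shown in this diff.

    # Hypothetical invocation of run_pretrained_alphafold.py with the new flags.
    import subprocess

    subprocess.run([
        'python', 'run_pretrained_alphafold.py',
        '--fasta_path', '/data/query.fasta',
        '--model_name', 'model_1',
        '--output_dir', '/data/out',
        '--uniref90_database_path', '/data/uniref90/uniref90.fasta',
        '--mgnify_database_path', '/data/mgnify/mgy_clusters.fa',
        '--bfd_database_path', '/data/bfd/bfd_prefix',
        '--uniclust30_database_path', '/data/uniclust30/uniclust30_2018_08',
        '--pdb70_database_path', '/data/pdb70/pdb70',
        '--template_mmcif_dir', '/data/pdb_mmcif/mmcif_files',
        '--max_template_date', '2021-10-08',
        '--preset', 'full_dbs',
    ], check=True)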