OpenDAS / FastFold — Commit 444c548a (unverified)

add hmmsearch (#58)

Authored Sep 06, 2022 by Fazzie-Maqianli, committed via GitHub on Sep 06, 2022
Parent: 9c0e7519

Showing 9 changed files with 1167 additions and 111 deletions (+1167, -111)
environment.yml                                +1    -0
fastfold/data/data_pipeline.py                 +278  -5
fastfold/data/feature_processing_multimer.py   +244  -0
fastfold/data/msa_pairing.py                   +1    -1
fastfold/data/parsers.py                       +278  -11
fastfold/data/templates.py                     +115  -0
fastfold/data/tools/hmmbuild.py                +137  -0
fastfold/data/tools/hmmsearch.py               +38   -49
inference.py                                   +75   -45
environment.yml
...
@@ -16,6 +16,7 @@ dependencies:
  - typing-extensions==3.10.0.2
  - einops
  - colossalai
  - pandas
  - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
  - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
  - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
...
fastfold/data/data_pipeline.py
...
@@ -14,19 +14,34 @@
 # limitations under the License.
 import os
 import collections
 import contextlib
 import dataclasses
 import datetime
 import json
 from multiprocessing import cpu_count
-from typing import Mapping, Optional, Sequence, Any
+import tempfile
+from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union

 import numpy as np

-from fastfold.data import templates, parsers, mmcif_parsing
-from fastfold.data.tools import jackhmmer, hhblits, hhsearch
+from fastfold.data import (
+    templates,
+    parsers,
+    mmcif_parsing,
+    msa_identifiers,
+    msa_pairing,
+    feature_processing_multimer,
+)
+from fastfold.data.parsers import Msa
+from fastfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
 from fastfold.data.tools.utils import to_date
 from fastfold.common import residue_constants, protein

 FeatureDict = Mapping[str, np.ndarray]
+TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]


 def empty_template_feats(n_res) -> FeatureDict:
     return {
...
@@ -216,6 +231,25 @@ def make_msa_features(
    return features


def run_msa_tool(
    msa_runner,
    fasta_path: str,
    msa_out_path: str,
    msa_format: str,
    max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]:
    """Runs an MSA tool, checking if output already exists first."""
    if (msa_format == "sto" and max_sto_sequences is not None):
        result = msa_runner.query(fasta_path, max_sto_sequences)[0]
    else:
        result = msa_runner.query(fasta_path)[0]

    with open(msa_out_path, "w") as f:
        f.write(result[msa_format])

    return result


class AlignmentRunner:
    """Runs alignment tools and saves the results"""

    def __init__(
...
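run_msa_tool is the small helper the alignment runner now uses for every MSA tool: it forwards an optional Stockholm truncation limit and writes the raw result to disk. A hedged usage sketch with the jackhmmer runner (the binary/database paths and the hit limit below are placeholders, not values taken from this commit):

    from fastfold.data.data_pipeline import run_msa_tool
    from fastfold.data.tools import jackhmmer

    # Placeholder paths; any runner exposing .query(fasta_path[, max_sto_sequences]) works.
    uniref90_runner = jackhmmer.Jackhmmer(
        binary_path="/usr/bin/jackhmmer",
        database_path="uniref90.fasta",
    )

    result = run_msa_tool(
        msa_runner=uniref90_runner,
        fasta_path="query.fasta",
        msa_out_path="alignments/uniref90_hits.sto",
        msa_format="sto",
        max_sto_sequences=10000,   # truncate the Stockholm output to 10k sequences
    )
    # result["sto"] holds the same Stockholm text that was written to msa_out_path.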
@@ -228,10 +262,12 @@ class AlignmentRunner:
        bfd_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        pdb70_database_path: Optional[str] = None,
        template_searcher: Optional[TemplateSearcher] = None,
        use_small_bfd: Optional[bool] = None,
        no_cpus: Optional[int] = None,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
        uniprot_max_hits: int = 50000,
    ):
        """
        Args:
...
@@ -411,6 +447,120 @@ class AlignmentRunner:
                f.write(hhblits_bfd_uniclust_result["a3m"])


@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
        fasta_file.write(fasta_str)
        fasta_file.seek(0)
        yield fasta_file.name


def convert_monomer_features(
    monomer_features: FeatureDict,
    chain_id: str
) -> FeatureDict:
    """Reshapes and modifies monomer features for multimer models."""
    converted = {}
    converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
    unnecessary_leading_dim_feats = {
        'sequence', 'domain_name', 'num_alignments', 'seq_length'
    }
    for feature_name, feature in monomer_features.items():
        if feature_name in unnecessary_leading_dim_feats:
            # asarray ensures it's a np.ndarray.
            feature = np.asarray(feature[0], dtype=feature.dtype)
        elif feature_name == 'aatype':
            # The multimer model performs the one-hot operation itself.
            feature = np.argmax(feature, axis=-1).astype(np.int32)
        elif feature_name == 'template_aatype':
            feature = np.argmax(feature, axis=-1).astype(np.int32)
            new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
            feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
        elif feature_name == 'template_all_atom_masks':
            feature_name = 'template_all_atom_mask'
        converted[feature_name] = feature
    return converted

def int_id_to_str_id(num: int) -> str:
    """Encodes a number as a string, using reverse spreadsheet style naming.

    Args:
        num: A positive integer.

    Returns:
        A string that encodes the positive integer using reverse spreadsheet
        style, naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ...
        This is the usual way to encode chain IDs in mmCIF files.
    """
    if num <= 0:
        raise ValueError(f'Only positive integers allowed, got {num}.')

    num = num - 1  # 1-based indexing.
    output = []
    while num >= 0:
        output.append(chr(num % 26 + ord('A')))
        num = num // 26 - 1
    return ''.join(output)
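The reverse-spreadsheet naming is easiest to see with a few concrete values. A self-contained sanity check that simply repeats the loop above, purely for illustration:

    def int_id_to_str_id(num: int) -> str:
        # Same loop as above: emit the least-significant base-26 "digit" first.
        if num <= 0:
            raise ValueError(f'Only positive integers allowed, got {num}.')
        num = num - 1  # 1-based indexing.
        output = []
        while num >= 0:
            output.append(chr(num % 26 + ord('A')))
            num = num // 26 - 1
        return ''.join(output)

    # 1 -> 'A', 2 -> 'B', 26 -> 'Z', 27 -> 'AA', 28 -> 'BA', 703 -> 'AAA'
    print([int_id_to_str_id(n) for n in (1, 2, 26, 27, 28, 703)])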

def add_assembly_features(
    all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
    """Add features to distinguish between chains.

    Args:
        all_chain_features: A dictionary which maps chain_id to a dictionary of
            features for each chain.

    Returns:
        all_chain_features: A dictionary which maps strings of the form
            `<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
            chains from a homodimer would have keys A_1 and A_2. Two chains
            from a heterodimer would have keys A_1 and B_1.
    """
    # Group the chains by sequence
    seq_to_entity_id = {}
    grouped_chains = collections.defaultdict(list)
    for chain_id, chain_features in all_chain_features.items():
        seq = str(chain_features['sequence'])
        if seq not in seq_to_entity_id:
            seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
        grouped_chains[seq_to_entity_id[seq]].append(chain_features)

    new_all_chain_features = {}
    chain_id = 1
    for entity_id, group_chain_features in grouped_chains.items():
        for sym_id, chain_features in enumerate(group_chain_features, start=1):
            new_all_chain_features[
                f'{int_id_to_str_id(entity_id)}_{sym_id}'
            ] = chain_features
            seq_length = chain_features['seq_length']
            chain_features['asym_id'] = (
                chain_id * np.ones(seq_length)
            ).astype(np.int64)
            chain_features['sym_id'] = (
                sym_id * np.ones(seq_length)
            ).astype(np.int64)
            chain_features['entity_id'] = (
                entity_id * np.ones(seq_length)
            ).astype(np.int64)
            chain_id += 1

    return new_all_chain_features

def pad_msa(np_example, min_num_seq):
    np_example = dict(np_example)
    num_seq = np_example['msa'].shape[0]
    if num_seq < min_num_seq:
        for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
            np_example[feat] = np.pad(
                np_example[feat],
                ((0, min_num_seq - num_seq), (0, 0)),
            )
        np_example['cluster_bias_mask'] = np.pad(
            np_example['cluster_bias_mask'],
            ((0, min_num_seq - num_seq),),
        )
    return np_example
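What pad_msa does to the MSA-shaped features, shown on toy arrays (the shapes are invented for the example; the feature names are the ones the function touches):

    import numpy as np

    toy = {
        'msa': np.zeros((3, 10), dtype=np.int32),               # 3 sequences x 10 residues
        'deletion_matrix': np.zeros((3, 10), dtype=np.float32),
        'bert_mask': np.ones((3, 10), dtype=np.float32),
        'msa_mask': np.ones((3, 10), dtype=np.float32),
        'cluster_bias_mask': np.ones((3,), dtype=np.float32),
    }

    padded = pad_msa(toy, min_num_seq=512)      # pad_msa as defined above
    print(padded['msa'].shape)                  # (512, 10) - zero rows appended
    print(padded['cluster_bias_mask'].shape)    # (512,)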

class DataPipeline:
    """Assembles input features."""

    def __init__(
...
@@ -478,6 +628,7 @@ class DataPipeline:
    def _parse_template_hits(
        self,
        alignment_dir: str,
        input_sequence: str = None,
        _alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        all_hits = {}
...
@@ -494,6 +645,12 @@ class DataPipeline:
                if (ext == ".hhr"):
                    hits = parsers.parse_hhr(read_template(start, size))
                    all_hits[name] = hits
                elif (name == "hmmsearch_output.sto"):
                    hits = parsers.parse_hmmsearch_sto(
                        read_template(start, size),
                        input_sequence,
                    )
                    all_hits[name] = hits

            fp.close()
        else:
...
@@ -505,6 +662,13 @@ class DataPipeline:
                    with open(path, "r") as fp:
                        hits = parsers.parse_hhr(fp.read())
                    all_hits[f] = hits
                elif (f == "hmm_output.sto"):
                    with open(path, "r") as fp:
                        hits = parsers.parse_hmmsearch_sto(
                            fp.read(),
                            input_sequence,
                        )
                    all_hits[f] = hits

        return all_hits
...
@@ -676,3 +840,112 @@ class DataPipeline:
        return {**core_feats, **template_features, **msa_features}


class DataPipelineMultimer:
    """Runs the alignment tools and assembles the input features."""

    def __init__(
        self,
        monomer_data_pipeline: DataPipeline,
    ):
        """Initializes the data pipeline.

        Args:
            monomer_data_pipeline: An instance of pipeline.DataPipeline - that
                runs the data pipeline for the monomer AlphaFold system.
            jackhmmer_binary_path: Location of the jackhmmer binary.
            uniprot_database_path: Location of the unclustered uniprot
                sequences, that will be searched with jackhmmer and used for
                MSA pairing.
            max_uniprot_hits: The maximum number of hits to return from uniprot.
            use_precomputed_msas: Whether to use pre-existing MSAs; see
                run_alphafold.
        """
        self._monomer_data_pipeline = monomer_data_pipeline

    def _process_single_chain(
        self,
        chain_id: str,
        sequence: str,
        description: str,
        chain_alignment_dir: str,
        is_homomer_or_monomer: bool
    ) -> FeatureDict:
        """Runs the monomer pipeline on a single chain."""
        chain_fasta_str = f'>{chain_id}\n{sequence}\n'

        if not os.path.exists(chain_alignment_dir):
            raise ValueError(f"Alignments for {chain_id} not found...")

        with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
            chain_features = self._monomer_data_pipeline.process_fasta(
                fasta_path=chain_fasta_path,
                alignment_dir=chain_alignment_dir
            )

            # We only construct the pairing features if there are 2 or more
            # unique sequences.
            if not is_homomer_or_monomer:
                all_seq_msa_features = self._all_seq_msa_features(
                    chain_fasta_path,
                    chain_alignment_dir
                )
                chain_features.update(all_seq_msa_features)
        return chain_features

    def _all_seq_msa_features(self, fasta_path, alignment_dir):
        """Get MSA features for unclustered uniprot, for pairing."""
        uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
        with open(uniprot_msa_path, "r") as fp:
            uniprot_msa_string = fp.read()
        msa = parsers.parse_stockholm(uniprot_msa_string)
        all_seq_features = make_msa_features([msa])
        valid_feats = msa_pairing.MSA_FEATURES + ('msa_species_identifiers',)
        feats = {
            f'{k}_all_seq': v
            for k, v in all_seq_features.items()
            if k in valid_feats
        }
        return feats

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """Creates features."""
        with open(fasta_path) as f:
            input_fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

        all_chain_features = {}
        sequence_features = {}
        is_homomer_or_monomer = len(set(input_seqs)) == 1
        for desc, seq in zip(input_descs, input_seqs):
            if seq in sequence_features:
                all_chain_features[desc] = copy.deepcopy(
                    sequence_features[seq]
                )
                continue

            chain_features = self._process_single_chain(
                chain_id=desc,
                sequence=seq,
                description=desc,
                chain_alignment_dir=os.path.join(alignment_dir, desc),
                is_homomer_or_monomer=is_homomer_or_monomer
            )

            chain_features = convert_monomer_features(
                chain_features,
                chain_id=desc
            )
            all_chain_features[desc] = chain_features
            sequence_features[seq] = chain_features

        all_chain_features = add_assembly_features(all_chain_features)

        np_example = feature_processing_multimer.pair_and_merge(
            all_chain_features=all_chain_features,
        )

        # Pad MSA to avoid zero-sized extra_msa.
        np_example = pad_msa(np_example, 512)

        return np_example
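Put together, the multimer path is driven by DataPipelineMultimer.process_fasta, which expects one alignment sub-directory per chain description in the FASTA. A hedged sketch of how it is wired up (the paths and the monomer-pipeline constructor arguments are placeholders, not values taken from this commit):

    from fastfold.data import data_pipeline

    # Hypothetical layout: alignments/chain_A/... and alignments/chain_B/... already
    # contain the per-chain MSAs (including uniprot_hits.sto used for pairing).
    monomer_pipeline = data_pipeline.DataPipeline(template_featurizer=None)
    multimer_pipeline = data_pipeline.DataPipelineMultimer(
        monomer_data_pipeline=monomer_pipeline,
    )

    feature_dict = multimer_pipeline.process_fasta(
        fasta_path="heterodimer.fasta",   # >chain_A ... >chain_B ...
        alignment_dir="alignments",
    )
    # feature_dict now holds the paired and merged multimer features, with the MSA
    # padded to at least 512 rows by pad_msa.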
fastfold/data/feature_processing_multimer.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""

from typing import Iterable, MutableMapping, List, Mapping

from fastfold.data import msa_pairing
from fastfold.common import residue_constants
import numpy as np

# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
    'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
    'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
    'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
    'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
    'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask',
    'num_alignments', 'num_templates', 'queue_size', 'residue_index',
    'resolution', 'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
    'template_all_atom_mask', 'template_all_atom_positions'
})

MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example."""
    # Note that an entity_id of 0 indicates padding.
    num_unique_chains = len(np.unique(np.concatenate(
        [np.unique(chain['entity_id'][chain['entity_id'] > 0])
         for chain in chains]
    )))
    return num_unique_chains == 1


def pair_and_merge(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
) -> Mapping[str, np.ndarray]:
    """Runs processing on features to augment, pair and merge.

    Args:
        all_chain_features: A MutableMap of dictionaries of features for each
            chain.

    Returns:
        A dictionary of features.
    """
    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())

    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

    if pair_msa_sequences:
        np_chains_list = msa_pairing.create_paired_features(
            chains=np_chains_list
        )
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
            np_chains_list
        )

    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )

    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )

    np_example = process_final(np_example)

    return np_example

def crop_chains(
    chains_list: List[Mapping[str, np.ndarray]],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int
) -> List[Mapping[str, np.ndarray]]:
    """Crops the MSAs for a set of chains.

    Args:
        chains_list: A list of chains to be cropped.
        msa_crop_size: The total number of sequences to crop from the MSA.
        pair_msa_sequences: Whether we are operating in sequence-pairing mode.
        max_templates: The maximum templates to use per chain.

    Returns:
        The chains cropped.
    """
    # Apply the cropping.
    cropped_chains = []
    for chain in chains_list:
        cropped_chain = _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates
        )
        cropped_chains.append(cropped_chain)

    return cropped_chains


def _crop_single_chain(
    chain: Mapping[str, np.ndarray],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int
) -> Mapping[str, np.ndarray]:
    """Crops msa sequences to `msa_crop_size`."""
    msa_size = chain['num_alignments']

    if pair_msa_sequences:
        msa_size_all_seq = chain['num_alignments_all_seq']
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This
        # keeps the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)
        )
        num_non_gapped_pairs = np.minimum(
            num_non_gapped_pairs, msa_crop_size_all_seq
        )

        # Restrict the unpaired crop size so that paired+unpaired sequences do
        # not exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32
        )
    return chain
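The crop budget in _crop_single_chain is easier to follow with numbers. A small walk-through using the module constant above (MSA_CROP_SIZE = 2048) and invented MSA sizes:

    import numpy as np

    msa_crop_size = 2048        # MSA_CROP_SIZE
    msa_size = 6000             # unpaired rows available for this chain (invented)
    msa_size_all_seq = 3000     # paired ("_all_seq") rows available (invented)

    # The paired block is capped at half the total budget.
    msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)  # 1024

    # Suppose 1000 of those kept paired rows are non-gap for this chain.
    num_non_gapped_pairs = np.minimum(1000, msa_crop_size_all_seq)            # 1000

    # Unpaired rows only get the budget the useful paired rows did not consume.
    max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)   # 1048
    final_unpaired = np.minimum(msa_size, max_msa_crop_size)                  # 1048

    # Non-gap paired (1000) + unpaired (1048) stays close to MSA_CROP_SIZE.
    print(int(msa_crop_size_all_seq), int(final_unpaired))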

def process_final(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Final processing steps in data pipeline, after merging and pairing."""
    np_example = _correct_msa_restypes(np_example)
    np_example = _make_seq_mask(np_example)
    np_example = _make_msa_mask(np_example)
    np_example = _filter_features(np_example)
    return np_example


def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants."""
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
    np_example['msa'] = np_example['msa'].astype(np.int32)
    return np_example


def _make_seq_mask(np_example):
    np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded."""
    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
    seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
    np_example['msa_mask'] *= seq_mask[None]
    return np_example


def _filter_features(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested."""
    return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
    """Postprocessing stage for per-chain features before merging."""
    num_chains = len(all_chain_features)
    for chain_features in all_chain_features.values():
        # Convert deletion matrices to float.
        chain_features['deletion_matrix'] = np.asarray(
            chain_features.pop('deletion_matrix_int'), dtype=np.float32
        )
        if 'deletion_matrix_int_all_seq' in chain_features:
            chain_features['deletion_matrix_all_seq'] = np.asarray(
                chain_features.pop('deletion_matrix_int_all_seq'),
                dtype=np.float32
            )

        chain_features['deletion_mean'] = np.mean(
            chain_features['deletion_matrix'], axis=0
        )

        # Add all_atom_mask and dummy all_atom_positions based on aatype.
        all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
            chain_features['aatype']
        ]
        chain_features['all_atom_mask'] = all_atom_mask
        chain_features['all_atom_positions'] = np.zeros(
            list(all_atom_mask.shape) + [3]
        )

        # Add assembly_num_chains.
        chain_features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for chain_features in all_chain_features.values():
        chain_features['entity_mask'] = (
            chain_features['entity_id'] != 0
        ).astype(np.int32)
fastfold/data/msa_pairing.py
...
@@ -23,7 +23,7 @@ import numpy as np
 import pandas as pd
 import scipy.linalg

-from openfold.np import residue_constants
+from fastfold.common import residue_constants

 # TODO: This stuff should probably also be in a config
...
fastfold/data/parsers.py
...
@@ -16,14 +16,43 @@
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class Msa:
    """Class representing a parsed MSA file"""

    sequences: Sequence[str]
    deletion_matrix: DeletionMatrix
    descriptions: Optional[Sequence[str]]

    def __post_init__(self):
        if not (
            len(self.sequences) ==
            len(self.deletion_matrix) ==
            len(self.descriptions)
        ):
            raise ValueError(
                "All fields for an MSA must have the same length"
            )

    def __len__(self):
        return len(self.sequences)

    def truncate(self, max_seqs: int):
        return Msa(
            sequences=self.sequences[:max_seqs],
            deletion_matrix=self.deletion_matrix[:max_seqs],
            descriptions=self.descriptions[:max_seqs],
        )


@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit."""
...
@@ -31,7 +60,7 @@ class TemplateHit:
    index: int
    name: str
    aligned_cols: int
-   sum_probs: float
+   sum_probs: Optional[float]
    query: str
    hit_sequence: str
    indices_query: List[int]
...
@@ -172,7 +201,9 @@ def _convert_sto_seq_to_a3m(
def convert_stockholm_to_a3m(
-    stockholm_format: str, max_sequences: Optional[int] = None
+    stockholm_format: str,
+    max_sequences: Optional[int] = None,
+    remove_first_row_gaps: bool = True,
) -> str:
    """Converts MSA in Stockholm format to the A3M format."""
    descriptions = {}
...
@@ -210,13 +241,19 @@ convert_stockholm_to_a3m(
    # Convert sto format to a3m line by line
    a3m_sequences = {}
-    # query_sequence is assumed to be the first sequence
-    query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != "-" for res in query_sequence]
+    if (remove_first_row_gaps):
+        # query_sequence is assumed to be the first sequence
+        query_sequence = next(iter(sequences.values()))
+        query_non_gaps = [res != "-" for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = "".join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
-        )
+        # Dots are optional in a3m format and are commonly removed.
+        out_sequence = sto_sequence.replace('.', '')
+        if (remove_first_row_gaps):
+            out_sequence = ''.join(
+                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
+            )
+        a3m_sequences[seqname] = out_sequence

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
...
@@ -225,6 +262,124 @@ convert_stockholm_to_a3m(
    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.


def _keep_line(line: str, seqnames: Set[str]) -> bool:
    """Function to decide which lines to keep."""
    if not line.strip():
        return True
    if line.strip() == '//':  # End tag
        return True
    if line.startswith('# STOCKHOLM'):  # Start tag
        return True
    if line.startswith('#=GC RF'):  # Reference Annotation Line
        return True
    if line[:4] == '#=GS':
        # Description lines - keep if sequence in list.
        _, seqname, _ = line.split(maxsplit=2)
        return seqname in seqnames
    elif line.startswith('#'):
        # Other markup - filter out
        return False
    else:
        # Alignment data - keep if sequence in list.
        seqname = line.partition(' ')[0]
        return seqname in seqnames


def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
    """Reads + truncates a Stockholm file while preventing excessive RAM usage."""
    seqnames = set()
    filtered_lines = []

    with open(stockholm_msa_path) as f:
        for line in f:
            if line.strip() and not line.startswith(('#', '//')):
                # Ignore blank lines, markup and end symbols - remainder are
                # alignment sequence parts.
                seqname = line.partition(' ')[0]
                seqnames.add(seqname)
                if len(seqnames) >= max_sequences:
                    break

        f.seek(0)
        for line in f:
            if _keep_line(line, seqnames):
                filtered_lines.append(line)

    return ''.join(filtered_lines)


def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
    """Removes empty columns (dashes-only) from a Stockholm MSA."""
    processed_lines = {}
    unprocessed_lines = {}
    for i, line in enumerate(stockholm_msa.splitlines()):
        if line.startswith('#=GC RF'):
            reference_annotation_i = i
            reference_annotation_line = line
            # Reached the end of this chunk of the alignment. Process chunk.
            _, _, first_alignment = line.rpartition(' ')
            mask = []
            for j in range(len(first_alignment)):
                for _, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    if alignment[j] != '-':
                        mask.append(True)
                        break
                else:  # Every row contained a hyphen - empty column.
                    mask.append(False)
            # Add reference annotation for processing with mask.
            unprocessed_lines[reference_annotation_i] = reference_annotation_line

            if not any(mask):
                # All columns were empty. Output empty lines for chunk.
                for line_index in unprocessed_lines:
                    processed_lines[line_index] = ''
            else:
                for line_index, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    masked_alignment = ''.join(
                        itertools.compress(alignment, mask)
                    )
                    processed_lines[line_index] = f'{prefix} {masked_alignment}'

            # Clear raw_alignments.
            unprocessed_lines = {}
        elif line.strip() and not line.startswith(('#', '//')):
            unprocessed_lines[i] = line
        else:
            processed_lines[i] = line
    return '\n'.join(
        (processed_lines[i] for i in range(len(processed_lines)))
    )


def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
    """Remove duplicate sequences (ignoring insertions wrt query)."""
    sequence_dict = collections.defaultdict(str)

    # First we must extract all sequences from the MSA.
    for line in stockholm_msa.splitlines():
        # Only consider the alignments - ignore reference annotation, empty
        # lines, descriptions or markup.
        if line.strip() and not line.startswith(('#', '//')):
            line = line.strip()
            seqname, alignment = line.split()
            sequence_dict[seqname] += alignment

    seen_sequences = set()
    seqnames = set()
    # First alignment is the query.
    query_align = next(iter(sequence_dict.values()))
    mask = [c != '-' for c in query_align]  # Mask is False for insertions.
    for seqname, alignment in sequence_dict.items():
        # Apply mask to remove all insertions from the string.
        masked_alignment = ''.join(itertools.compress(alignment, mask))
        if masked_alignment in seen_sequences:
            continue
        else:
            seen_sequences.add(masked_alignment)
            seqnames.add(seqname)

    filtered_lines = []
    for line in stockholm_msa.splitlines():
        if _keep_line(line, seqnames):
            filtered_lines.append(line)

    return '\n'.join(filtered_lines) + '\n'
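A tiny example of the new Stockholm helpers on a throwaway MSA string (the sequences are invented; truncate_stockholm_msa takes a file path, so the sketch writes one first):

    import tempfile

    from fastfold.data import parsers

    sto = """# STOCKHOLM 1.0
    query  MKVLLA
    seq1   MKVQLA
    seq2   MKVQLA
    seq3   M-VQLA
    //
    """

    # seq2 is an exact duplicate of seq1 after masking to the query columns, so it is dropped.
    dedup = parsers.deduplicate_stockholm_msa(sto)

    # Keep only the first two distinct sequence names (query and seq1), markup included.
    with tempfile.NamedTemporaryFile("w", suffix=".sto", delete=False) as f:
        f.write(sto)
    truncated = parsers.truncate_stockholm_msa(f.name, max_sequences=2)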

def _get_hhr_line_regex_groups(
    regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
...
@@ -278,7 +433,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
            "Could not parse section: %s. Expected this:\n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )

-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
        float(x) for x in match.groups()
    ]
...
@@ -386,3 +541,115 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values


def _get_indices(sequence: str, start: int) -> List[int]:
    """Returns indices for non-gap/insert residues starting at the given index."""
    indices = []
    counter = start
    for symbol in sequence:
        # Skip gaps but add a placeholder so that the alignment is preserved.
        if symbol == '-':
            indices.append(-1)
        # Skip deleted residues, but increase the counter.
        elif symbol.islower():
            counter += 1
        # Normal aligned residue. Increase the counter and append to indices.
        else:
            indices.append(counter)
            counter += 1
    return indices


@dataclasses.dataclass(frozen=True)
class HitMetadata:
    pdb_id: str
    chain: str
    start: int
    end: int
    length: int
    text: str


def _parse_hmmsearch_description(description: str) -> HitMetadata:
    """Parses the hmmsearch A3M sequence description line."""
    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217  Free text
    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
    match = re.match(
        r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
        description.strip())

    if not match:
        raise ValueError(f'Could not parse description: "{description}".')

    return HitMetadata(
        pdb_id=match[1],
        chain=match[2],
        start=int(match[3]),
        end=int(match[4]),
        length=int(match[5]),
        text=match[6]
    )


def parse_hmmsearch_a3m(
    query_sequence: str,
    a3m_string: str,
    skip_first: bool = True
) -> Sequence[TemplateHit]:
    """Parses an a3m string produced by hmmsearch.

    Args:
        query_sequence: The query sequence.
        a3m_string: The a3m string produced by hmmsearch.
        skip_first: Whether to skip the first sequence in the a3m string.

    Returns:
        A sequence of `TemplateHit` results.
    """
    # Zip the descriptions and MSAs together, skip the first query sequence.
    parsed_a3m = list(zip(*parse_fasta(a3m_string)))
    if skip_first:
        parsed_a3m = parsed_a3m[1:]

    indices_query = _get_indices(query_sequence, start=0)

    hits = []
    for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
        if 'mol:protein' not in hit_description:
            continue  # Skip non-protein chains.
        metadata = _parse_hmmsearch_description(hit_description)
        # Aligned columns are only the match states.
        aligned_cols = sum(
            [r.isupper() and r != '-' for r in hit_sequence]
        )
        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)

        hit = TemplateHit(
            index=i,
            name=f'{metadata.pdb_id}_{metadata.chain}',
            aligned_cols=aligned_cols,
            sum_probs=None,
            query=query_sequence,
            hit_sequence=hit_sequence.upper(),
            indices_query=indices_query,
            indices_hit=indices_hit,
        )
        hits.append(hit)

    return hits


def parse_hmmsearch_sto(
    output_string: str, input_sequence: str
) -> Sequence[TemplateHit]:
    """Gets parsed template hits from the raw string output by the tool."""
    a3m_string = convert_stockholm_to_a3m(
        output_string, remove_first_row_gaps=False
    )
    template_hits = parse_hmmsearch_a3m(
        query_sequence=input_sequence,
        a3m_string=a3m_string,
        skip_first=False
    )
    return template_hits
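For reference, the description regex above accepts lines like the ones in its own comment; a quick check against the first example (the printed values follow directly from the regex groups):

    from fastfold.data import parsers

    meta = parsers._parse_hmmsearch_description(
        ">4pqx_A/2-217 [subseq from] mol:protein length:217  Free text"
    )
    # HitMetadata(pdb_id='4pqx', chain='A', start=2, end=217, length=217, text='Free text')
    print(meta.pdb_id, meta.chain, meta.start, meta.end, meta.length, meta.text)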
fastfold/data/templates.py
...
@@ -1105,3 +1105,118 @@ class TemplateHitFeaturizer:
        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )


class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        logging.info("Searching for template for: %s", query_sequence)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # DISCREPANCY: This filtering scheme that saves time
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                hit=hit,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )
            if prefilter_result.error:
                errors.append(prefilter_result.error)
            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)
            if prefilter_result.valid:
                filtered.append(hit)

        filtered = list(
            sorted(
                filtered,
                key=lambda x: x.sum_probs if x.sum_probs else 0.,
                reverse=True,
            )
        )

        idx = list(range(len(filtered)))
        if (self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            if (len(already_seen) >= self._max_hits):
                break
            hit = filtered[i]
            result = _process_single_hit(
                query_sequence=query_sequence,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path
            )

            if result.error:
                errors.append(result.error)
            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.debug(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name,
                    result.error,
                    result.warning,
                )
            else:
                already_seen_key = result.features["template_sequence"]
                if (already_seen_key in already_seen):
                    continue
                # Increment the hit counter, since we got features out of this hit.
                already_seen.add(already_seen_key)
                for k in template_features:
                    template_features[k].append(result.features[k])

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = {
                "template_aatype": np.zeros(
                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
                    np.float32,
                ),
                "template_all_atom_masks": np.zeros(
                    (1, num_res, residue_constants.atom_type_num), np.float32
                ),
                "template_all_atom_positions": np.zeros(
                    (1, num_res, residue_constants.atom_type_num, 3), np.float32
                ),
                "template_domain_names": np.array([''.encode()], dtype=np.object),
                "template_sequence": np.array([''.encode()], dtype=np.object),
                "template_sum_probs": np.array([0], dtype=np.float32),
            }

        return TemplateSearchResult(
            features=template_features,
            errors=errors,
            warnings=warnings,
        )
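A hedged sketch of feeding the new featurizer: template hits come from parsers.parse_hmmsearch_sto on an hmmsearch Stockholm output, and the constructor arguments mirror the ones inference.py passes below (the concrete paths, date and query here are placeholders):

    from fastfold.data import parsers, templates

    featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir="mmcif_files/",              # placeholder paths and date
        max_template_date="2022-09-06",
        max_hits=4,
        kalign_binary_path="/usr/bin/kalign",
        release_dates_path=None,
        obsolete_pdbs_path=None,
    )

    query_sequence = "MKVLLA"                                # placeholder query
    sto = open("alignments/chain_A/hmm_output.sto").read()   # hmmsearch -A output
    hits = parsers.parse_hmmsearch_sto(sto, input_sequence=query_sequence)
    result = featurizer.get_templates(query_sequence=query_sequence, hits=hits)
    # result.features holds stacked template_* arrays, or the all-zero default
    # template when no hit survives filtering.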
fastfold/data/tools/hmmbuild.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""

import os
import re
import subprocess

from absl import logging

from fastfold.data.tools import utils


class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self, *, binary_path: str, singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces
                HMMBuild to just use a common substitution score matrix.

        Raises:
            RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds a HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            raise ValueError(
                f'Invalid model_construction {model_construction} - only'
                'hand and fast supported.'
            )

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    'hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                    stdout.decode('utf-8'),
                    stderr.decode('utf-8'),
                )

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8'))
                )

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
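A short usage sketch of the wrapper (the binary path and MSA are placeholders); build_profile_from_a3m strips lowercase insertion characters before handing the alignment to hmmbuild:

    from fastfold.data.tools import hmmbuild

    builder = hmmbuild.Hmmbuild(binary_path="/usr/bin/hmmbuild")  # placeholder path

    # Lowercase 'q' marks an insertion in A3M; it is removed before profile building.
    a3m = ">query\nMKVLLA\n>hit1\nMKVqLLA\n"
    hmm_profile = builder.build_profile_from_a3m(a3m)
    # hmm_profile is the text of the profile written by `hmmbuild --amino`.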
fastfold/data/tools/hmmsearch.py
...
@@ -16,23 +16,23 @@
import os
import subprocess
-import logging
from typing import Optional, Sequence

+from absl import logging
+
from fastfold.data import parsers
from fastfold.data.tools import hmmbuild
-from fastfold.utils import general_utils as utils
+from fastfold.data.tools import utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None
    ):
        """Initializes the Python hmmsearch wrapper.
...
@@ -51,85 +51,73 @@ class Hmmsearch(object):
        self.database_path = database_path

        if flags is None:
            # Default hmmsearch run settings.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand'
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self,
                       hmm: str,
                       output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend(['-A', out_path,
                        hmm_input_path,
                        self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f'hmmsearch ({os.path.basename(self.database_path)}) query'
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(out_path) as f:
                out_msa = f.read()
...
@@ -138,7 +126,8 @@ class Hmmsearch(object):
    @staticmethod
    def get_template_hits(
        output_string: str,
        input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
...
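End to end, Hmmsearch.query builds a profile from a Stockholm MSA via the hmmbuild wrapper and then searches the sequence database with it; the same flow can also be driven manually through query_with_hmm. A hedged sketch (all paths are placeholders):

    from fastfold.data.tools import hmmbuild, hmmsearch

    searcher = hmmsearch.Hmmsearch(
        binary_path="/usr/bin/hmmsearch",        # placeholder paths
        hmmbuild_binary_path="/usr/bin/hmmbuild",
        database_path="pdb_seqres.txt",
    )

    msa_sto = open("alignments/chain_A/uniref90_hits.sto").read()   # placeholder MSA

    # One call: searcher.query(msa_sto) builds the HMM (model_construction='hand')
    # and runs hmmsearch. The two explicit steps look like this:
    builder = hmmbuild.Hmmbuild(binary_path="/usr/bin/hmmbuild")
    hmm = builder.build_profile_from_sto(msa_sto, model_construction="hand")
    sto_hits = searcher.query_with_hmm(hmm, output_dir="alignments/chain_A")
    # sto_hits (also written to alignments/chain_A/hmm_output.sto) can be parsed
    # with parsers.parse_hmmsearch_sto(sto_hits, input_sequence).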
inference.py
...
@@ -32,6 +32,7 @@ from fastfold.common import protein, residue_constants
 from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
+from fastfold.data.tools import hhsearch, hmmsearch
 from fastfold.workflow.template import FastFoldDataWorkFlow
 from fastfold.utils import inject_fastnn
 from fastfold.data.parsers import parse_fasta
...
@@ -111,8 +112,41 @@ def inference_model(rank, world_size, result_q, batch, args):
def main(args):
    if args.model_preset == "multimer":
        inference_multimer_model(args)
    else:
        inference_monomer_model(args)


def inference_multimer_model(args):
    print("running in multimer mode...")
    # feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
    predict_max_templates = 4
    if not args.use_precomputed_alignments:
        template_searcher = hmmsearch.Hmmsearch(
            binary_path=args.hmmsearch_binary_path,
            hmmbuild_binary_path=args.hmmbuild_binary_path,
            database_path=args.pdb_seqres_database_path,
        )
    else:
        template_searcher = None

    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=predict_max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path,
    )


def inference_monomer_model(args):
    print("running in monomer mode...")
    config = model_config(args.model_name)
    global_is_multimer = True if args.model_preset == "multimer" else False
    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
...
@@ -120,7 +154,8 @@ def main(args):
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )

    use_small_bfd = args.preset == 'reduced_dbs'  # (args.bfd_database_path is None)

    if use_small_bfd:
...
@@ -158,10 +193,7 @@
    print("Generating features...")
    local_alignment_dir = os.path.join(alignment_dir, tag)
-    if global_is_multimer:
-        print("running in multimer mode...")
-        feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
-    else:
-        if (args.use_precomputed_alignments is None):
+    if (args.use_precomputed_alignments is None):
        if not os.path.exists(local_alignment_dir):
            os.makedirs(local_alignment_dir)
...
@@ -206,7 +238,6 @@
    processed_feature_dict = feature_processor.process_features(
        feature_dict,
        mode='predict',
-        is_multimer=global_is_multimer,
    )

    batch = processed_feature_dict
...
@@ -251,7 +282,6 @@
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
...