Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4f38c826
Commit
4f38c826
authored
Dec 11, 2023
by
Christina Floristean
Browse files
Added alignment indexing to multimer pipeline and fixed jackhmmer query return type
parent
bcabb8e3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
40 deletions
+66
-40
openfold/data/data_modules.py
openfold/data/data_modules.py
+6
-1
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+57
-36
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+3
-3
No files found.
openfold/data/data_modules.py
View file @
4f38c826
...
@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
alignment_index
=
None
alignment_index
=
None
if
self
.
alignment_index
is
not
None
:
alignment_index
=
{
k
:
v
for
k
,
v
in
self
.
alignment_index
.
items
()
if
f
'
{
mmcif_id
}
_'
in
k
}
if
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
:
if
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
:
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
...
@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
.fasta"
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
.fasta"
)
data
=
self
.
data_pipeline
.
process_fasta
(
data
=
self
.
data_pipeline
.
process_fasta
(
fasta_path
=
path
,
fasta_path
=
path
,
alignment_dir
=
self
.
alignment_dir
alignment_dir
=
self
.
alignment_dir
,
alignment_index
=
alignment_index
)
)
if
self
.
_output_raw
:
if
self
.
_output_raw
:
...
...
openfold/data/data_pipeline.py
View file @
4f38c826
...
@@ -794,7 +794,7 @@ class DataPipeline:
...
@@ -794,7 +794,7 @@ class DataPipeline:
def
_get_msas
(
self
,
def
_get_msas
(
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
):
):
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
alignment_index
)
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
alignment_index
)
if
(
len
(
msa_data
)
==
0
):
if
(
len
(
msa_data
)
==
0
):
...
@@ -814,7 +814,7 @@ class DataPipeline:
...
@@ -814,7 +814,7 @@ class DataPipeline:
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msas
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
...
@@ -846,7 +846,7 @@ class DataPipeline:
...
@@ -846,7 +846,7 @@ class DataPipeline:
self
,
self
,
fasta_path
:
str
,
fasta_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
seqemb_mode
:
bool
=
False
,
seqemb_mode
:
bool
=
False
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
"""Assembles features for a single sequence in a FASTA file"""
...
@@ -899,7 +899,7 @@ class DataPipeline:
...
@@ -899,7 +899,7 @@ class DataPipeline:
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
chain_id
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
seqemb_mode
:
bool
=
False
,
seqemb_mode
:
bool
=
False
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
...
@@ -946,7 +946,7 @@ class DataPipeline:
...
@@ -946,7 +946,7 @@ class DataPipeline:
is_distillation
:
bool
=
True
,
is_distillation
:
bool
=
True
,
chain_id
:
Optional
[
str
]
=
None
,
chain_id
:
Optional
[
str
]
=
None
,
_structure_index
:
Optional
[
str
]
=
None
,
_structure_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
seqemb_mode
:
bool
=
False
,
seqemb_mode
:
bool
=
False
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
...
@@ -1000,7 +1000,7 @@ class DataPipeline:
...
@@ -1000,7 +1000,7 @@ class DataPipeline:
self
,
self
,
core_path
:
str
,
core_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
seqemb_mode
:
bool
=
False
,
seqemb_mode
:
bool
=
False
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
...
@@ -1156,30 +1156,49 @@ class DataPipelineMultimer:
...
@@ -1156,30 +1156,49 @@ class DataPipelineMultimer:
sequence
:
str
,
sequence
:
str
,
description
:
str
,
description
:
str
,
chain_alignment_dir
:
str
,
chain_alignment_dir
:
str
,
chain_alignment_index
:
Optional
[
Any
],
is_homomer_or_monomer
:
bool
is_homomer_or_monomer
:
bool
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Runs the monomer pipeline on a single chain."""
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str
=
f
'>
{
chain_id
}
\n
{
sequence
}
\n
'
chain_fasta_str
=
f
'>
{
chain_id
}
\n
{
sequence
}
\n
'
if
not
os
.
path
.
exists
(
chain_alignment_dir
):
if
chain_alignment_index
is
None
and
not
os
.
path
.
exists
(
chain_alignment_dir
):
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
fasta_path
=
chain_fasta_path
,
fasta_path
=
chain_fasta_path
,
alignment_dir
=
chain_alignment_dir
alignment_dir
=
chain_alignment_dir
,
alignment_index
=
chain_alignment_index
)
)
# We only construct the pairing features if there are 2 or more unique
# We only construct the pairing features if there are 2 or more unique
# sequences.
# sequences.
if
not
is_homomer_or_monomer
:
if
not
is_homomer_or_monomer
:
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_alignment_dir
chain_alignment_dir
,
chain_alignment_index
)
)
chain_features
.
update
(
all_seq_msa_features
)
chain_features
.
update
(
all_seq_msa_features
)
return
chain_features
return
chain_features
@
staticmethod
@
staticmethod
def
_all_seq_msa_features
(
alignment_dir
):
def
_all_seq_msa_features
(
alignment_dir
,
alignment_index
):
"""Get MSA features for unclustered uniprot, for pairing."""
"""Get MSA features for unclustered uniprot, for pairing."""
if
alignment_index
is
not
None
:
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
alignment_index
[
"db"
]),
"rb"
)
def
read_msa
(
start
,
size
):
fp
.
seek
(
start
)
msa
=
fp
.
read
(
size
).
decode
(
"utf-8"
)
return
msa
start
,
size
=
next
(
iter
((
start
,
size
)
for
name
,
start
,
size
in
alignment_index
[
"files"
]
if
name
==
'uniprot_hits.sto'
))
msa
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
))
fp
.
close
()
else
:
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
if
not
os
.
path
.
exists
(
uniprot_msa_path
):
if
not
os
.
path
.
exists
(
uniprot_msa_path
):
chain_id
=
os
.
path
.
basename
(
os
.
path
.
normpath
(
alignment_dir
))
chain_id
=
os
.
path
.
basename
(
os
.
path
.
normpath
(
alignment_dir
))
...
@@ -1189,6 +1208,7 @@ class DataPipelineMultimer:
...
@@ -1189,6 +1208,7 @@ class DataPipelineMultimer:
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
all_seq_features
=
make_msa_features
([
msa
])
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_species_identifiers'
,
'msa_species_identifiers'
,
...
@@ -1202,15 +1222,14 @@ class DataPipelineMultimer:
...
@@ -1202,15 +1222,14 @@ class DataPipelineMultimer:
def
process_fasta
(
self
,
def
process_fasta
(
self
,
fasta_path
:
str
,
fasta_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
Any
]
=
None
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Creates features."""
"""Creates features."""
with
open
(
fasta_path
)
as
f
:
with
open
(
fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
all_chain_features
=
{}
all_chain_features
=
{}
sequence_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
...
@@ -1221,16 +1240,22 @@ class DataPipelineMultimer:
...
@@ -1221,16 +1240,22 @@ class DataPipelineMultimer:
)
)
continue
continue
if
alignment_index
is
not
None
:
chain_alignment_index
=
alignment_index
.
get
(
desc
)
chain_alignment_dir
=
alignment_dir
else
:
chain_alignment_index
=
None
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
)
chain_features
=
self
.
_process_single_chain
(
chain_features
=
self
.
_process_single_chain
(
chain_id
=
desc
,
chain_id
=
desc
,
sequence
=
seq
,
sequence
=
seq
,
description
=
desc
,
description
=
desc
,
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
),
chain_alignment_dir
=
chain_alignment_dir
,
chain_alignment_index
=
chain_alignment_index
,
is_homomer_or_monomer
=
is_homomer_or_monomer
is_homomer_or_monomer
=
is_homomer_or_monomer
)
)
chain_features
=
convert_monomer_features
(
chain_features
=
convert_monomer_features
(
chain_features
,
chain_features
,
chain_id
=
desc
chain_id
=
desc
...
@@ -1238,15 +1263,12 @@ class DataPipelineMultimer:
...
@@ -1238,15 +1263,12 @@ class DataPipelineMultimer:
all_chain_features
[
desc
]
=
chain_features
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing_multimer
.
pair_and_merge
(
np_example
=
feature_processing_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
all_chain_features
=
all_chain_features
,
)
)
# Pad MSA to avoid zero-sized extra_msa.
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
np_example
=
pad_msa
(
np_example
,
512
)
...
@@ -1279,55 +1301,54 @@ class DataPipelineMultimer:
...
@@ -1279,55 +1301,54 @@ class DataPipelineMultimer:
self
,
self
,
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
Any
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
all_chain_features
=
{}
all_chain_features
=
{}
sequence_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
list
(
mmcif
.
chain_to_seqres
.
values
())))
==
1
is_homomer_or_monomer
=
len
(
set
(
list
(
mmcif
.
chain_to_seqres
.
values
())))
==
1
for
chain_id
,
seq
in
mmcif
.
chain_to_seqres
.
items
():
for
chain_id
,
seq
in
mmcif
.
chain_to_seqres
.
items
():
desc
=
"_"
.
join
([
mmcif
.
file_id
,
chain_id
])
desc
=
"_"
.
join
([
mmcif
.
file_id
,
chain_id
])
if
seq
in
sequence_features
:
if
seq
in
sequence_features
:
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
sequence_features
[
seq
]
sequence_features
[
seq
]
)
)
continue
continue
if
alignment_index
is
not
None
:
chain_alignment_index
=
alignment_index
.
get
(
desc
)
chain_alignment_dir
=
alignment_dir
else
:
chain_alignment_index
=
None
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
)
chain_features
=
self
.
_process_single_chain
(
chain_features
=
self
.
_process_single_chain
(
chain_id
=
desc
,
chain_id
=
desc
,
sequence
=
seq
,
sequence
=
seq
,
description
=
desc
,
description
=
desc
,
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
),
chain_alignment_dir
=
chain_alignment_dir
,
chain_alignment_index
=
chain_alignment_index
,
is_homomer_or_monomer
=
is_homomer_or_monomer
is_homomer_or_monomer
=
is_homomer_or_monomer
)
)
chain_features
=
convert_monomer_features
(
chain_features
=
convert_monomer_features
(
chain_features
,
chain_features
,
chain_id
=
desc
chain_id
=
desc
)
)
mmcif_feats
=
self
.
get_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
self
.
get_mmcif_features
(
mmcif
,
chain_id
)
chain_features
.
update
(
mmcif_feats
)
chain_features
.
update
(
mmcif_feats
)
all_chain_features
[
desc
]
=
chain_features
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing_multimer
.
pair_and_merge
(
np_example
=
feature_processing_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
all_chain_features
=
all_chain_features
,
)
)
# Pad MSA to avoid zero-sized extra_msa.
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
np_example
=
pad_msa
(
np_example
,
512
)
return
np_example
return
np_example
\ No newline at end of file
openfold/data/tools/jackhmmer.py
View file @
4f38c826
...
@@ -190,11 +190,11 @@ class Jackhmmer:
...
@@ -190,11 +190,11 @@ class Jackhmmer:
def
query
(
self
,
def
query
(
self
,
input_fasta_path
:
str
,
input_fasta_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
)
->
Sequence
[
Sequence
[
Mapping
[
str
,
Any
]]
]
:
return
self
.
query_multiple
([
input_fasta_path
],
max_sequences
)
[
0
]
return
self
.
query_multiple
([
input_fasta_path
],
max_sequences
)
def
query_multiple
(
self
,
def
query_multiple
(
self
,
input_fasta_paths
:
str
,
input_fasta_paths
:
Sequence
[
str
]
,
max_sequences
:
Optional
[
int
]
=
None
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Sequence
[
Mapping
[
str
,
Any
]]]:
)
->
Sequence
[
Sequence
[
Mapping
[
str
,
Any
]]]:
"""Queries the database using Jackhmmer."""
"""Queries the database using Jackhmmer."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment