Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
3279b28d
Commit
3279b28d
authored
Feb 03, 2022
by
Gustaf Ahdritz
Browse files
Remove alignment index
parent
55bf27d4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
110 deletions
+35
-110
openfold/data/data_modules.py
openfold/data/data_modules.py
+4
-24
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+31
-83
train_openfold.py
train_openfold.py
+0
-3
No files found.
openfold/data/data_modules.py
View file @
3279b28d
...
...
@@ -37,7 +37,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path
:
Optional
[
str
]
=
None
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
_alignment_index
:
Optional
[
Any
]
=
None
):
"""
Args:
...
...
@@ -84,7 +83,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
self
.
_alignment_index
=
_alignment_index
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
if
(
mode
not
in
valid_modes
):
...
...
@@ -96,9 +94,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if
(
_alignment_index
is
not
None
):
self
.
_chain_ids
=
list
(
_alignment_index
.
keys
())
elif
(
mapping_path
is
None
):
if
(
mapping_path
is
None
):
self
.
_chain_ids
=
list
(
os
.
listdir
(
alignment_dir
))
else
:
with
open
(
mapping_path
,
"r"
)
as
f
:
...
...
@@ -125,7 +121,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
not
self
.
_output_raw
):
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
):
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
...
...
@@ -144,7 +140,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
)
return
data
...
...
@@ -159,11 +154,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name
=
self
.
idx_to_chain_id
(
idx
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
_alignment_index
=
None
if
(
self
.
_alignment_index
is
not
None
):
alignment_dir
=
self
.
alignment_dir
_alignment_index
=
self
.
_alignment_index
[
name
]
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
spl
=
name
.
rsplit
(
'_'
,
1
)
if
(
len
(
spl
)
==
2
):
...
...
@@ -175,11 +165,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
if
(
os
.
path
.
exists
(
path
+
".cif"
)):
data
=
self
.
_parse_mmcif
(
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
,
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
)
elif
(
os
.
path
.
exists
(
path
+
".core"
)):
data
=
self
.
data_pipeline
.
process_core
(
path
+
".core"
,
alignment_dir
,
_alignment_index
,
path
+
".core"
,
alignment_dir
,
)
elif
(
os
.
path
.
exists
(
path
+
".pdb"
)):
data
=
self
.
data_pipeline
.
process_pdb
(
...
...
@@ -187,7 +177,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
,
)
else
:
raise
ValueError
(
"Invalid file type"
)
...
...
@@ -196,7 +185,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
=
self
.
data_pipeline
.
process_fasta
(
fasta_path
=
path
,
alignment_dir
=
alignment_dir
,
_alignment_index
=
_alignment_index
,
)
if
(
self
.
_output_raw
):
...
...
@@ -486,7 +474,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
train_epoch_len
:
int
=
50000
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
...
@@ -538,12 +525,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
self
.
_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
...
...
@@ -568,7 +549,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
_output_raw
=
True
,
_alignment_index
=
self
.
_alignment_index
,
)
distillation_dataset
=
None
...
...
openfold/data/data_pipeline.py
View file @
3279b28d
...
...
@@ -422,38 +422,8 @@ class DataPipeline:
def
_parse_msa_data
(
self
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
def
read_msa
(
start
,
size
):
fp
.
seek
(
start
)
msa
=
fp
.
read
(
size
).
decode
(
"utf-8"
)
return
msa
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".a3m"
):
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
continue
msa_data
[
name
]
=
data
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
...
@@ -478,25 +448,8 @@ class DataPipeline:
def
_parse_template_hits
(
self
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
'rb'
)
def
read_template
(
start
,
size
):
fp
.
seek
(
start
)
return
fp
.
read
(
size
).
decode
(
"utf-8"
)
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
...
@@ -512,9 +465,8 @@ class DataPipeline:
self
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
)
if
(
len
(
msa_data
)
==
0
):
if
(
input_sequence
is
None
):
...
...
@@ -544,7 +496,6 @@ class DataPipeline:
self
,
fasta_path
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
...
...
@@ -558,7 +509,7 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -571,7 +522,7 @@ class DataPipeline:
num_res
=
num_res
,
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
return
{
**
sequence_features
,
...
...
@@ -584,7 +535,6 @@ class DataPipeline:
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a specific chain in an mmCIF object.
...
...
@@ -602,7 +552,7 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -610,7 +560,7 @@ class DataPipeline:
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
])
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
return
{
**
mmcif_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -620,7 +570,6 @@ class DataPipeline:
alignment_dir
:
str
,
is_distillation
:
bool
=
True
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a protein in a PDB file.
...
...
@@ -637,14 +586,14 @@ class DataPipeline:
is_distillation
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
self
.
template_featurizer
,
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
return
{
**
pdb_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -652,7 +601,6 @@ class DataPipeline:
self
,
core_path
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a protein in a ProteinNet .core file.
...
...
@@ -665,7 +613,7 @@ class DataPipeline:
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
train_openfold.py
View file @
3279b28d
...
...
@@ -370,9 +370,6 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment