Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b5f885d0
Commit
b5f885d0
authored
Oct 26, 2021
by
Gustaf Ahdritz
Browse files
Add ProteinNet support to parser
parent
94ab346e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
137 additions
and
91 deletions
+137
-91
README.md
README.md
+3
-3
openfold/data/data_modules.py
openfold/data/data_modules.py
+2
-6
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+67
-80
scripts/prepare_proteinnet_msas.py
scripts/prepare_proteinnet_msas.py
+64
-0
train_openfold.py
train_openfold.py
+1
-2
No files found.
README.md
View file @
b5f885d0
...
...
@@ -17,9 +17,9 @@ Try it out with our [Colab notebook](https://colab.research.google.com/github/aq
(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with
or without
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
mixed
precision.
`bfloat16`
training is not currently supported, but will be
in the
future.
with
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
mixed
precision.
`bfloat16`
training is not currently supported, but will be
in the
future.
## Installation (Linux)
...
...
openfold/data/data_modules.py
View file @
b5f885d0
...
...
@@ -46,11 +46,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing:
* bfd_uniclust_hits.a3m/small_bfd_hits.sto
* mgnify_hits.a3m
* pdb70_hits.hhr
* uniref90_hits.a3m
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
config:
A dataset config object. See openfold.config
mapping_path:
...
...
@@ -97,7 +94,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
use_small_bfd
=
use_small_bfd
,
)
if
(
not
self
.
output_raw
):
...
...
openfold/data/data_pipeline.py
View file @
b5f885d0
...
...
@@ -260,47 +260,65 @@ class DataPipeline:
def
__init__
(
self
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
):
self
.
template_featurizer
=
template_featurizer
self
.
use_small_bfd
=
use_small_bfd
def
_parse_
alignment_output
(
def
_parse_
msa_data
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.a3m"
)
with
open
(
uniref90_out_path
,
"r"
)
as
f
:
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
"mgnify_hits.a3m"
)
with
open
(
mgnify_out_path
,
"r"
)
as
f
:
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
"pdb70_hits.hhr"
)
with
open
(
pdb70_out_path
,
"r"
)
as
f
:
hhsearch_hits
=
parsers
.
parse_hhr
(
f
.
read
())
if
self
.
use_small_bfd
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"small_bfd_hits.sto"
)
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
f
.
read
()
msa_data
=
{}
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
if
(
ext
==
".a3m"
):
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
fp
.
read
())
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
fp
.
read
()
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
continue
return
{
"uniref90_msa"
:
uniref90_msa
,
"uniref90_deletion_matrix"
:
uniref90_deletion_matrix
,
"mgnify_msa"
:
mgnify_msa
,
"mgnify_deletion_matrix"
:
mgnify_deletion_matrix
,
"hhsearch_hits"
:
hhsearch_hits
,
"bfd_msa"
:
bfd_msa
,
"bfd_deletion_matrix"
:
bfd_deletion_matrix
,
}
msa_data
[
f
]
=
data
return
msa_data
def
_parse_template_hits
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
if
(
ext
==
".hhr"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
return
all_hits
def
_process_msa_feats
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
)
msas
,
deletion_matrices
=
zip
(
*
[
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
msa_features
=
make_msa_features
(
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
)
return
msa_features
def
process_fasta
(
self
,
...
...
@@ -319,13 +337,13 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
alignmen
ts
=
self
.
_parse_
alignment_output
(
alignment_dir
)
hi
ts
=
self
.
_parse_
template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
None
,
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
sequence_features
=
make_sequence_features
(
...
...
@@ -334,18 +352,8 @@ class DataPipeline:
num_res
=
num_res
,
)
msa_features
=
make_msa_features
(
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
return
{
**
sequence_features
,
**
msa_features
,
...
...
@@ -373,28 +381,18 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
print
(
len
(
hits_cat
))
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
msa_features
=
make_msa_features
(
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
return
{
**
mmcif_feats
,
**
templates_result
.
features
,
**
msa_features
}
...
...
@@ -413,26 +411,15 @@ class DataPipeline:
pdb_feats
=
make_pdb_features
(
protein_object
)
alignmen
ts
=
self
.
_parse_
alignment_output
(
alignment_dir
)
hi
ts
=
self
.
_parse_
template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
protein_object
.
aatyp
e
,
query_sequence
=
input_sequenc
e
,
query_pdb_code
=
None
,
query_release_date
=
None
,
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
msa_features
=
make_msa_features
(
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
return
{
**
pdb_feats
,
**
templates_result
.
features
,
**
msa_features
}
scripts/prepare_proteinnet_msas.py
0 → 100644
View file @
b5f885d0
import
argparse
import
logging
import
os
import
shutil
def main(args):
    """Group ProteinNet MSA files into per-chain directories.

    For each MSA file named ``{PDB_ID}_{X}_{CHAIN_ID}.*`` in ``args.msa_dir``
    that has a matching mmCIF file in ``args.mmcif_dir``, copy (or move) it to
    ``args.out_dir/{PDB_ID}_{CHAIN_ID}/``.

    Args:
        args: Namespace with ``msa_dir``, ``mmcif_dir``, ``out_dir``,
            ``copy`` (bool), and ``max_count`` (int or None) attributes.
    """
    count = 0
    # -1 sentinel: count never equals it, so no bound is applied.
    max_count = args.max_count if args.max_count is not None else -1

    # Both listings are sorted so a single forward pass over the mmCIFs
    # suffices while scanning the MSAs (merge-join style).
    msas = sorted(os.listdir(args.msa_dir))
    mmcifs = sorted(os.listdir(args.mmcif_dir))

    mmcif_idx = 0
    for f in msas:
        if count == max_count:
            break

        path = os.path.join(args.msa_dir, f)
        name = os.path.splitext(f)[0]
        spl = name.upper().split('_')
        # Expect filenames shaped {PDB_ID}_{X}_{CHAIN_ID}; skip others.
        if len(spl) != 3:
            continue

        pdb_id, _, chain_id = spl

        # Advance past mmCIFs that sort before this PDB ID. Bound-check the
        # index: without it, an MSA whose ID sorts after every mmCIF name
        # (or an empty mmcif_dir) raises IndexError.
        while (
            mmcif_idx < len(mmcifs)
            and pdb_id > os.path.splitext(mmcifs[mmcif_idx])[0].upper()
        ):
            mmcif_idx += 1
        if mmcif_idx == len(mmcifs):
            # mmCIF list exhausted; no later MSA can match either.
            break

        # Only consider files with matching mmCIF files
        if pdb_id == os.path.splitext(mmcifs[mmcif_idx])[0].upper():
            dirname = os.path.join(args.out_dir, '_'.join([pdb_id, chain_id]))
            os.makedirs(dirname, exist_ok=True)
            dest = os.path.join(dirname, f)
            if args.copy:
                shutil.copyfile(path, dest)
            else:
                os.rename(path, dest)
            count += 1
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "msa_dir", type=str,
        help="Directory containing ProteinNet MSAs",
    )
    parser.add_argument(
        "mmcif_dir", type=str,
        help="Directory containing PDB mmCIFs",
    )
    parser.add_argument(
        "out_dir", type=str,
        help="Directory to which output should be saved",
    )
    # argparse's type=bool treats ANY non-empty string as True (so
    # "--copy False" would still be True); parse the string explicitly.
    parser.add_argument(
        "--copy",
        type=lambda s: s.strip().lower() in ("true", "1", "yes"),
        default=True,
        help="Whether to copy the MSAs to out_dir rather than moving them",
    )
    parser.add_argument(
        "--max_count", type=int, default=None,
        help="A bound on the number of MSAs to process",
    )

    args = parser.parse_args()
    main(args)
train_openfold.py
View file @
b5f885d0
...
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
os
#os.environ["CUDA_VISIBLE_DEVICES"] = "
6
"
#os.environ["CUDA_VISIBLE_DEVICES"] = "
5
"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
...
...
@@ -14,7 +14,6 @@ import time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.plugins
import
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment