Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b5f885d0
Commit
b5f885d0
authored
Oct 26, 2021
by
Gustaf Ahdritz
Browse files
Add ProteinNet support to parser
parent
94ab346e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
137 additions
and
91 deletions
+137
-91
README.md
README.md
+3
-3
openfold/data/data_modules.py
openfold/data/data_modules.py
+2
-6
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+67
-80
scripts/prepare_proteinnet_msas.py
scripts/prepare_proteinnet_msas.py
+64
-0
train_openfold.py
train_openfold.py
+1
-2
No files found.
README.md
View file @
b5f885d0
...
@@ -17,9 +17,9 @@ Try it out with our [Colab notebook](https://colab.research.google.com/github/aq
...
@@ -17,9 +17,9 @@ Try it out with our [Colab notebook](https://colab.research.google.com/github/aq
(not yet visible from Colab because the repo is still private).
(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with
or without
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
with
[
DeepSpeed
](
https://github.com/microsoft/deepspeed
)
and with
mixed
mixed
precision.
`bfloat16`
training is not currently supported, but will be
precision.
`bfloat16`
training is not currently supported, but will be
in the
in the
future.
future.
## Installation (Linux)
## Installation (Linux)
...
...
openfold/data/data_modules.py
View file @
b5f885d0
...
@@ -46,11 +46,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -46,11 +46,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
output by an AlignmentRunner
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing:
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
* bfd_uniclust_hits.a3m/small_bfd_hits.sto
files.
* mgnify_hits.a3m
* pdb70_hits.hhr
* uniref90_hits.a3m
config:
config:
A dataset config object. See openfold.config
A dataset config object. See openfold.config
mapping_path:
mapping_path:
...
@@ -97,7 +94,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -97,7 +94,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
use_small_bfd
=
use_small_bfd
,
)
)
if
(
not
self
.
output_raw
):
if
(
not
self
.
output_raw
):
...
...
openfold/data/data_pipeline.py
View file @
b5f885d0
...
@@ -260,47 +260,65 @@ class DataPipeline:
...
@@ -260,47 +260,65 @@ class DataPipeline:
def
__init__
(
def
__init__
(
self
,
self
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
):
):
self
.
template_featurizer
=
template_featurizer
self
.
template_featurizer
=
template_featurizer
self
.
use_small_bfd
=
use_small_bfd
def
_parse_
alignment_output
(
def
_parse_
msa_data
(
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.a3m"
)
msa_data
=
{}
with
open
(
uniref90_out_path
,
"r"
)
as
f
:
for
f
in
os
.
listdir
(
alignment_dir
):
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
"mgnify_hits.a3m"
)
with
open
(
mgnify_out_path
,
"r"
)
as
f
:
if
(
ext
==
".a3m"
):
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
fp
.
read
())
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
"pdb70_hits.hhr"
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
with
open
(
pdb70_out_path
,
"r"
)
as
f
:
elif
(
ext
==
".sto"
):
hhsearch_hits
=
parsers
.
parse_hhr
(
f
.
read
())
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
if
self
.
use_small_bfd
:
fp
.
read
()
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"small_bfd_hits.sto"
)
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
f
.
read
()
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
else
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
continue
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
return
{
msa_data
[
f
]
=
data
"uniref90_msa"
:
uniref90_msa
,
"uniref90_deletion_matrix"
:
uniref90_deletion_matrix
,
return
msa_data
"mgnify_msa"
:
mgnify_msa
,
"mgnify_deletion_matrix"
:
mgnify_deletion_matrix
,
def
_parse_template_hits
(
"hhsearch_hits"
:
hhsearch_hits
,
self
,
"bfd_msa"
:
bfd_msa
,
alignment_dir
:
str
,
"bfd_deletion_matrix"
:
bfd_deletion_matrix
,
)
->
Mapping
[
str
,
Any
]:
}
all_hits
=
{}
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
if
(
ext
==
".hhr"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
return
all_hits
def
_process_msa_feats
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
)
msas
,
deletion_matrices
=
zip
(
*
[
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
msa_features
=
make_msa_features
(
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
)
return
msa_features
def
process_fasta
(
def
process_fasta
(
self
,
self
,
...
@@ -319,13 +337,13 @@ class DataPipeline:
...
@@ -319,13 +337,13 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
alignmen
ts
=
self
.
_parse_
alignment_output
(
alignment_dir
)
hi
ts
=
self
.
_parse_
template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
templates_result
=
self
.
template_featurizer
.
get_templates
(
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_pdb_code
=
None
,
query_release_date
=
None
,
query_release_date
=
None
,
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
)
sequence_features
=
make_sequence_features
(
sequence_features
=
make_sequence_features
(
...
@@ -334,18 +352,8 @@ class DataPipeline:
...
@@ -334,18 +352,8 @@ class DataPipeline:
num_res
=
num_res
,
num_res
=
num_res
,
)
)
msa_features
=
make_msa_features
(
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
return
{
return
{
**
sequence_features
,
**
sequence_features
,
**
msa_features
,
**
msa_features
,
...
@@ -373,28 +381,18 @@ class DataPipeline:
...
@@ -373,28 +381,18 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
print
(
len
(
hits_cat
))
templates_result
=
self
.
template_featurizer
.
get_templates
(
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
)
msa_features
=
make_msa_features
(
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
return
{
**
mmcif_feats
,
**
templates_result
.
features
,
**
msa_features
}
return
{
**
mmcif_feats
,
**
templates_result
.
features
,
**
msa_features
}
...
@@ -413,26 +411,15 @@ class DataPipeline:
...
@@ -413,26 +411,15 @@ class DataPipeline:
pdb_feats
=
make_pdb_features
(
protein_object
)
pdb_feats
=
make_pdb_features
(
protein_object
)
alignmen
ts
=
self
.
_parse_
alignment_output
(
alignment_dir
)
hi
ts
=
self
.
_parse_
template_hits
(
alignment_dir
)
hits_cat
=
sum
(
hits
.
values
(),
[])
templates_result
=
self
.
template_featurizer
.
get_templates
(
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
protein_object
.
aatyp
e
,
query_sequence
=
input_sequenc
e
,
query_pdb_code
=
None
,
query_pdb_code
=
None
,
query_release_date
=
None
,
query_release_date
=
None
,
hits
=
alignments
[
"hhsearch_hits"
]
,
hits
=
hits_cat
,
)
)
msa_features
=
make_msa_features
(
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
)
msas
=
(
alignments
[
"uniref90_msa"
],
alignments
[
"bfd_msa"
],
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
)
return
{
**
pdb_feats
,
**
templates_result
.
features
,
**
msa_features
}
return
{
**
pdb_feats
,
**
templates_result
.
features
,
**
msa_features
}
scripts/prepare_proteinnet_msas.py
0 → 100644
View file @
b5f885d0
import
argparse
import
logging
import
os
import
shutil
def main(args):
    """Group ProteinNet MSA files into per-chain output directories.

    Each MSA file is expected to be named {PDB_ID}_{...}_{CHAIN_ID}.<ext>.
    If a matching mmCIF file (same PDB ID, case-insensitive) exists in
    args.mmcif_dir, the MSA is copied (args.copy) or moved into
    args.out_dir/{PDB_ID}_{CHAIN_ID}/. MSAs without a matching mmCIF, or
    with unexpected names, are skipped.
    """
    count = 0
    # -1 can never equal count, so max_count=None means "no limit"
    max_count = args.max_count if args.max_count is not None else -1

    # Both listings are sorted so the two directories can be matched in a
    # single linear pass (merge-join on uppercased PDB ID).
    msas = sorted(os.listdir(args.msa_dir))
    mmcifs = sorted(os.listdir(args.mmcif_dir))

    mmcif_idx = 0
    for f in msas:
        if count == max_count:
            break

        path = os.path.join(args.msa_dir, f)
        name = os.path.splitext(f)[0]
        spl = name.upper().split('_')
        # Skip files that don't follow the {PDB_ID}_{...}_{CHAIN_ID} shape
        if len(spl) != 3:
            continue

        pdb_id, _, chain_id = spl

        # Advance the mmCIF cursor until it catches up with this PDB ID.
        # The bounds check prevents an IndexError when the MSA's ID sorts
        # after every available mmCIF.
        while (
            mmcif_idx < len(mmcifs)
            and pdb_id > os.path.splitext(mmcifs[mmcif_idx])[0].upper()
        ):
            mmcif_idx += 1

        if mmcif_idx == len(mmcifs):
            # msas is sorted, so every remaining MSA also sorts after the
            # last mmCIF — no further matches are possible.
            break

        # Only consider files with matching mmCIF files
        if pdb_id == os.path.splitext(mmcifs[mmcif_idx])[0].upper():
            dirname = os.path.join(args.out_dir, '_'.join([pdb_id, chain_id]))
            os.makedirs(dirname, exist_ok=True)
            dest = os.path.join(dirname, f)
            if args.copy:
                shutil.copyfile(path, dest)
            else:
                os.rename(path, dest)
            count += 1
def str_to_bool(value):
    """Parse a command-line string ("true"/"false"/"1"/"0"/...) into a bool.

    argparse's ``type=bool`` treats ANY non-empty string — including the
    literal string "False" — as truthy, so ``--copy False`` would silently
    mean True. An explicit converter avoids that pitfall while keeping the
    flag's interface (``--copy <value>``, default True) unchanged.
    """
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "msa_dir", type=str,
        help="Directory containing ProteinNet MSAs"
    )
    parser.add_argument(
        "mmcif_dir", type=str,
        help="Directory containing PDB mmCIFs"
    )
    parser.add_argument(
        "out_dir", type=str,
        help="Directory to which output should be saved"
    )
    parser.add_argument(
        "--copy", type=str_to_bool, default=True,
        help="Whether to copy the MSAs to out_dir rather than moving them"
    )
    parser.add_argument(
        "--max_count", type=int, default=None,
        help="A bound on the number of MSAs to process"
    )

    args = parser.parse_args()

    main(args)
train_openfold.py
View file @
b5f885d0
...
@@ -2,7 +2,7 @@ import argparse
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
logging
import
os
import
os
#os.environ["CUDA_VISIBLE_DEVICES"] = "
6
"
#os.environ["CUDA_VISIBLE_DEVICES"] = "
5
"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
#os.environ["NODE_RANK"]="0"
...
@@ -14,7 +14,6 @@ import time
...
@@ -14,7 +14,6 @@ import time
import
numpy
as
np
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.plugins
import
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
import
torch
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment