Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
23886619
Commit
23886619
authored
Oct 18, 2021
by
Gustaf Ahdritz
Browse files
Add mmCIF cache generation script, remove verbose warnings
parent
f649cccd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
40 deletions
+96
-40
openfold/data/data_modules.py
openfold/data/data_modules.py
+12
-7
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+1
-25
scripts/generate_mmcif_cache.py
scripts/generate_mmcif_cache.py
+65
-0
train_openfold.py
train_openfold.py
+18
-8
No files found.
openfold/data/data_modules.py
View file @
23886619
...
...
@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
template_release_dates_cache_path
is
None
):
logging
.
warning
(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache
s
.py before running OpenFold"
"scripts/generate_mmcif_cache.py before running OpenFold"
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
...
...
@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
)
else
:
self
.
val_dataset
=
None
else
:
self
.
predict_dataset
=
dataset_gen
(
data_dir
=
self
.
predict_data_dir
,
...
...
@@ -387,12 +389,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
def val_dataloader(self):
    """Return a DataLoader over the validation set, or None when no
    validation dataset was configured for this run."""
    if self.val_dataset is None:
        return None
    loader_cfg = self.config.data_module.data_loaders
    return torch.utils.data.DataLoader(
        self.val_dataset,
        batch_size=loader_cfg.batch_size,
        num_workers=loader_cfg.num_workers,
        collate_fn=self._gen_batch_collator("eval")
    )
def
predict_dataloader
(
self
):
return
torch
.
utils
.
data
.
DataLoader
(
...
...
openfold/data/mmcif_parsing.py
View file @
23886619
...
...
@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
"resolution"
]
=
float
(
raw_resolution
)
except
ValueError
:
logging
.
warning
(
logging
.
info
(
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
)
...
...
@@ -475,27 +475,3 @@ def get_atom_coords(
all_atom_mask
[
res_index
]
=
mask
return
all_atom_positions
,
all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Serially scan mmcif_dir for .cif files and write a JSON cache to
    out_path mapping each file ID to its release date and chain count."""
    cache = {}
    for f in os.listdir(mmcif_dir):
        # Only genuine .cif files participate in the cache
        if not f.endswith(".cif"):
            continue
        with open(os.path.join(mmcif_dir, f), "r") as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(f)[0]
        parse_result = parse(file_id=file_id, mmcif_string=mmcif_string)
        parsed = parse_result.mmcif_object
        if parsed is None:
            # Unparseable entries are logged and dropped from the cache
            logging.warning(f"Could not parse {f}. Skipping...")
            continue
        cache[file_id] = {
            "release_date": parsed.header["release_date"],
            "no_chains": len(list(parsed.structure.get_chains())),
        }

    with open(out_path, "w") as fp:
        fp.write(json.dumps(cache))
scripts/generate_mmcif_cache.py
0 → 100644
View file @
23886619
import argparse
import json
import logging
import os
import sys
from functools import partial
from multiprocessing import Pool

sys.path.append(".") # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
    """Extract cache metadata from a single mmCIF file.

    Returns a one-entry dict mapping the file ID to its release date and
    chain count, or an empty dict when the file cannot be parsed.
    """
    path = os.path.join(args.mmcif_dir, f)
    with open(path, "r") as fp:
        contents = fp.read()

    file_id, _ = os.path.splitext(f)
    parse_result = parse(file_id=file_id, mmcif_string=contents)
    mmcif_object = parse_result.mmcif_object
    if mmcif_object is None:
        # Unparseable files yield an empty dict so the caller's
        # dict.update() is a harmless no-op
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    metadata = {
        "release_date": mmcif_object.header["release_date"],
        "no_chains": len(list(mmcif_object.structure.get_chains())),
    }
    return {file_id: metadata}
def main(args):
    """Build a JSON metadata cache for a directory of mmCIF files.

    Lists .cif files in args.mmcif_dir, parses them in a worker pool, and
    writes the merged {file_id: metadata} mapping to args.output_path as
    JSON. Files that fail to parse contribute nothing to the cache.
    """
    # endswith rather than substring containment, consistent with the
    # cache generator previously in mmcif_parsing.py, so that names that
    # merely contain ".cif" (e.g. "x.cif.bak") are not picked up
    files = [f for f in os.listdir(args.mmcif_dir) if f.endswith(".cif")]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # imap_unordered: completion order is irrelevant because each
            # result dict carries its own file_id key
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data))
if __name__ == "__main__":
    # Command-line entry point: positional input/output paths plus
    # optional parallelism knobs.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "mmcif_dir",
        type=str,
        help="Directory containing mmCIF files",
    )
    arg_parser.add_argument(
        "output_path",
        type=str,
        help="Path for .json output",
    )
    arg_parser.add_argument(
        "--no_workers",
        type=int,
        default=4,
        help="Number of workers to use for parsing",
    )
    arg_parser.add_argument(
        "--chunksize",
        type=int,
        default=10,
        help="How many files should be distributed to each worker at a time",
    )

    main(arg_parser.parse_args())
train_openfold.py
View file @
23886619
...
...
@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
return
{
"loss"
:
loss
,
"pred"
:
outputs
[
"sm"
][
"positions"
][
-
1
].
detach
()
}
return
{
"loss"
:
loss
}
def training_epoch_end(self, outs):
    """Pickle the final step's predicted positions to a timestamped file."""
    # Detached structure-module positions from the last step, on host memory
    prediction = outs[-1]["pred"].cpu()
    # NOTE(review): assumes a "prediction/" directory already exists — confirm
    out_path = "prediction/preds_" + str(time.strftime("%H:%M:%S")) + ".pickle"
    with open(out_path, "wb") as f:
        pickle.dump(prediction, f, protocol=pickle.HIGHEST_PROTOCOL)
def validation_step(self, batch, batch_idx):
    """Run one validation step under the EMA weights.

    Returns:
        dict with key "val_loss" holding the validation loss.
    """
    # At the start of validation, swap in the EMA weights, caching the
    # current training weights so validation_epoch_end can restore them.
    if self.cached_weights is None:
        # BUG FIX: was `model.state_dict()`, which referenced an undefined
        # global name; the training weights to cache live on `self.model`.
        self.cached_weights = self.model.state_dict()
        self.model.load_state_dict(self.ema.state_dict()["params"])

    # Calculate validation loss
    outputs = self(batch)
    # Keep only the final recycling iteration of each batch feature
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss = self.loss(outputs, batch)
    return {"val_loss": loss}
#def validation_step(self, batch, batch_idx):
# outputs = self(batch)
def validation_epoch_end(self, _):
    """Restore the cached training weights once validation is finished."""
    # Undo the EMA-weight swap performed at the start of validation
    self.model.load_state_dict(self.cached_weights)
    self.cached_weights = None
def
configure_optimizers
(
self
,
learning_rate
:
float
=
1e-3
,
...
...
@@ -140,7 +149,8 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"Output of templates.generate_mmcif_cache"
help
=
"""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
parser
.
add_argument
(
"--use_small_bfd"
,
type
=
bool
,
default
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment