Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
23886619
Commit
23886619
authored
Oct 18, 2021
by
Gustaf Ahdritz
Browse files
Add mmCIF cache generation script, remove verbose warnings
parent
f649cccd
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
40 deletions
+96
-40
openfold/data/data_modules.py
openfold/data/data_modules.py
+12
-7
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+1
-25
scripts/generate_mmcif_cache.py
scripts/generate_mmcif_cache.py
+65
-0
train_openfold.py
train_openfold.py
+18
-8
No files found.
openfold/data/data_modules.py
View file @
23886619
...
@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
template_release_dates_cache_path
is
None
):
if
(
template_release_dates_cache_path
is
None
):
logging
.
warning
(
logging
.
warning
(
"Template release dates cache does not exist. Remember to run "
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache
s
.py before running OpenFold"
"scripts/generate_mmcif_cache.py before running OpenFold"
)
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
...
@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
mode
=
"eval"
,
)
)
else
:
self
.
val_dataset
=
None
else
:
else
:
self
.
predict_dataset
=
dataset_gen
(
self
.
predict_dataset
=
dataset_gen
(
data_dir
=
self
.
predict_data_dir
,
data_dir
=
self
.
predict_data_dir
,
...
@@ -387,6 +389,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -387,6 +389,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
def
val_dataloader
(
self
):
def
val_dataloader
(
self
):
if
(
self
.
val_dataset
is
not
None
):
return
torch
.
utils
.
data
.
DataLoader
(
return
torch
.
utils
.
data
.
DataLoader
(
self
.
val_dataset
,
self
.
val_dataset
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
...
@@ -394,6 +397,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -394,6 +397,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
collate_fn
=
self
.
_gen_batch_collator
(
"eval"
)
collate_fn
=
self
.
_gen_batch_collator
(
"eval"
)
)
)
return
None
def
predict_dataloader
(
self
):
def
predict_dataloader
(
self
):
return
torch
.
utils
.
data
.
DataLoader
(
return
torch
.
utils
.
data
.
DataLoader
(
self
.
predict_dataset
,
self
.
predict_dataset
,
...
...
openfold/data/mmcif_parsing.py
View file @
23886619
...
@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
...
@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution
=
parsed_info
[
res_key
][
0
]
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
"resolution"
]
=
float
(
raw_resolution
)
header
[
"resolution"
]
=
float
(
raw_resolution
)
except
ValueError
:
except
ValueError
:
logging
.
warning
(
logging
.
info
(
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
)
)
...
@@ -475,27 +475,3 @@ def get_atom_coords(
...
@@ -475,27 +475,3 @@ def get_atom_coords(
all_atom_mask
[
res_index
]
=
mask
all_atom_mask
[
res_index
]
=
mask
return
all_atom_positions
,
all_atom_mask
return
all_atom_positions
,
all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Build a JSON cache of per-file mmCIF metadata.

    Scans mmcif_dir for .cif files, parses each one, and writes a JSON
    object to out_path mapping file IDs (filename without extension) to
    {"release_date": ..., "no_chains": ...}. Files that fail to parse are
    logged and skipped.
    """
    data = {}
    for fname in os.listdir(mmcif_dir):
        # Only consider mmCIF files; ignore anything else in the directory.
        if not fname.endswith(".cif"):
            continue
        with open(os.path.join(mmcif_dir, fname), "r") as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(fname)[0]
        result = parse(file_id=file_id, mmcif_string=mmcif_string)
        if result.mmcif_object is None:
            # Best-effort: a bad file should not abort the whole cache run.
            logging.warning(f"Could not parse {fname}. Skipping...")
            continue
        obj = result.mmcif_object
        data[file_id] = {
            "release_date": obj.header["release_date"],
            "no_chains": len(list(obj.structure.get_chains())),
        }
    with open(out_path, "w") as fp:
        fp.write(json.dumps(data))
scripts/generate_mmcif_cache.py
0 → 100644
View file @
23886619
import argparse
import json
import logging
import os
import sys
from functools import partial
from multiprocessing import Pool

sys.path.append(".")  # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
    """Parse one mmCIF file and return {file_id: metadata} or {} on failure.

    Worker function for the multiprocessing pool: reads the file named f
    from args.mmcif_dir, extracts the release date and chain count, and
    returns a single-entry dict keyed by the file ID (name sans extension).
    """
    path = os.path.join(args.mmcif_dir, f)
    with open(path, "r") as fp:
        mmcif_string = fp.read()

    file_id = os.path.splitext(f)[0]
    result = parse(file_id=file_id, mmcif_string=mmcif_string)
    if result.mmcif_object is None:
        # Unparseable file: log it and let the caller carry on.
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    obj = result.mmcif_object
    local_data = {
        "release_date": obj.header["release_date"],
        "no_chains": len(list(obj.structure.get_chains())),
    }
    return {file_id: local_data}
def main(args):
    """Generate the mmCIF metadata cache in parallel.

    Lists .cif files under args.mmcif_dir, parses them with a pool of
    args.no_workers processes (args.chunksize files handed to a worker at
    a time), and writes the merged {file_id: metadata} mapping as JSON to
    args.output_path.
    """
    # endswith, not substring matching: ".cif" in f would also pick up
    # names like "x.cif.bak" (and matches the filter used elsewhere).
    files = [f for f in os.listdir(args.mmcif_dir) if f.endswith(".cif")]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # imap_unordered: results arrive in completion order, which is
            # fine because each result is keyed by its own file_id.
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data))
if __name__ == "__main__":
    # CLI entry point: two positional arguments (input directory, output
    # JSON path) and two tuning knobs for the multiprocessing pool.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mmcif_dir", type=str, help="Directory containing mmCIF files"
    )
    parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )
    args = parser.parse_args()

    main(args)
train_openfold.py
View file @
23886619
...
@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
return
{
"loss"
:
loss
,
"pred"
:
outputs
[
"sm"
][
"positions"
][
-
1
].
detach
()
}
return
{
"loss"
:
loss
}
def
training_epoch_end
(
self
,
outs
):
def
validation_step
(
self
,
batch
,
batch_idx
):
out
=
outs
[
-
1
][
"pred"
].
cpu
()
# At the start of validation, load the EMA weights
with
open
(
"prediction/preds_"
+
str
(
time
.
strftime
(
"%H:%M:%S"
))
+
".pickle"
,
"wb"
)
as
f
:
if
(
self
.
cached_weights
is
None
):
pickle
.
dump
(
out
,
f
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
self
.
cached_weights
=
model
.
state_dict
()
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
#def validation_step(self, batch, batch_idx):
# Calculate validation loss
# outputs = self(batch)
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
return
{
"val_loss"
:
loss
}
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
self
.
model
.
load_state_dict
(
self
.
cached_weights
)
self
.
cached_weights
=
None
def
configure_optimizers
(
self
,
def
configure_optimizers
(
self
,
learning_rate
:
float
=
1e-3
,
learning_rate
:
float
=
1e-3
,
...
@@ -140,7 +149,8 @@ if __name__ == "__main__":
...
@@ -140,7 +149,8 @@ if __name__ == "__main__":
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"Output of templates.generate_mmcif_cache"
help
=
"""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_small_bfd"
,
type
=
bool
,
default
=
False
,
"--use_small_bfd"
,
type
=
bool
,
default
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment