Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
23886619
"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "16310b269f866e6f4b7968ba6780e54a4f7b76f6"
Commit
23886619
authored
Oct 18, 2021
by
Gustaf Ahdritz
Browse files
Add mmCIF cache generation script, remove verbose warnings
parent
f649cccd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
40 deletions
+96
-40
openfold/data/data_modules.py
openfold/data/data_modules.py
+12
-7
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+1
-25
scripts/generate_mmcif_cache.py
scripts/generate_mmcif_cache.py
+65
-0
train_openfold.py
train_openfold.py
+18
-8
No files found.
openfold/data/data_modules.py
View file @
23886619
...
@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
template_release_dates_cache_path
is
None
):
if
(
template_release_dates_cache_path
is
None
):
logging
.
warning
(
logging
.
warning
(
"Template release dates cache does not exist. Remember to run "
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache
s
.py before running OpenFold"
"scripts/generate_mmcif_cache.py before running OpenFold"
)
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
...
@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
mode
=
"eval"
,
)
)
else
:
self
.
val_dataset
=
None
else
:
else
:
self
.
predict_dataset
=
dataset_gen
(
self
.
predict_dataset
=
dataset_gen
(
data_dir
=
self
.
predict_data_dir
,
data_dir
=
self
.
predict_data_dir
,
...
@@ -387,12 +389,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -387,12 +389,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
def val_dataloader(self):
    """Return a DataLoader over the validation set, or None when no
    validation dataset was configured."""
    if self.val_dataset is None:
        return None
    loader_config = self.config.data_module.data_loaders
    return torch.utils.data.DataLoader(
        self.val_dataset,
        batch_size=loader_config.batch_size,
        num_workers=loader_config.num_workers,
        collate_fn=self._gen_batch_collator("eval"),
    )
def
predict_dataloader
(
self
):
def
predict_dataloader
(
self
):
return
torch
.
utils
.
data
.
DataLoader
(
return
torch
.
utils
.
data
.
DataLoader
(
...
...
openfold/data/mmcif_parsing.py
View file @
23886619
...
@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
...
@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution
=
parsed_info
[
res_key
][
0
]
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
"resolution"
]
=
float
(
raw_resolution
)
header
[
"resolution"
]
=
float
(
raw_resolution
)
except
ValueError
:
except
ValueError
:
logging
.
warning
(
logging
.
info
(
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
)
)
...
@@ -475,27 +475,3 @@ def get_atom_coords(
...
@@ -475,27 +475,3 @@ def get_atom_coords(
all_atom_mask
[
res_index
]
=
mask
all_atom_mask
[
res_index
]
=
mask
return
all_atom_positions
,
all_atom_mask
return
all_atom_positions
,
all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Build a JSON cache of header metadata for a directory of mmCIF files.

    Every ``.cif`` file in ``mmcif_dir`` is parsed; its release date and
    chain count are recorded under the file's basename (extension dropped).
    The resulting mapping is written as JSON to ``out_path``. Files that
    fail to parse are logged and skipped rather than aborting the run.
    """
    data = {}
    cif_names = [name for name in os.listdir(mmcif_dir) if name.endswith(".cif")]
    for name in cif_names:
        with open(os.path.join(mmcif_dir, name), "r") as fp:
            mmcif_string = fp.read()

        file_id, _ = os.path.splitext(name)
        parse_result = parse(file_id=file_id, mmcif_string=mmcif_string)
        mmcif_object = parse_result.mmcif_object
        if mmcif_object is None:
            logging.warning(f"Could not parse {name}. Skipping...")
            continue

        data[file_id] = {
            "release_date": mmcif_object.header["release_date"],
            "no_chains": len(list(mmcif_object.structure.get_chains())),
        }

    with open(out_path, "w") as fp:
        fp.write(json.dumps(data))
scripts/generate_mmcif_cache.py
0 → 100644
View file @
23886619
import argparse
from functools import partial
import json
import logging
from multiprocessing import Pool
import os
import sys

sys.path.append(".")  # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
    """Parse one mmCIF file and return a single-entry metadata dict.

    Reads file ``f`` from ``args.mmcif_dir`` and returns
    ``{file_id: {"release_date": ..., "no_chains": ...}}``. When the file
    cannot be parsed, logs the failure and returns an empty dict so that
    callers can merge results unconditionally.
    """
    path = os.path.join(args.mmcif_dir, f)
    with open(path, "r") as fp:
        mmcif_string = fp.read()

    file_id, _ = os.path.splitext(f)
    parse_result = parse(file_id=file_id, mmcif_string=mmcif_string)
    mmcif_object = parse_result.mmcif_object
    if mmcif_object is None:
        # Skip unparseable files instead of failing the whole cache build.
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    metadata = {
        "release_date": mmcif_object.header["release_date"],
        "no_chains": len(list(mmcif_object.structure.get_chains())),
    }
    return {file_id: metadata}
def main(args):
    """Generate a JSON metadata cache for a directory of mmCIF files.

    Parses every ``.cif`` file in ``args.mmcif_dir`` across a pool of
    ``args.no_workers`` processes, merges the per-file metadata dicts
    produced by ``parse_file``, and writes the combined mapping as JSON
    to ``args.output_path``.
    """
    # endswith(), not a substring test: `".cif" in f` also matched names
    # like "x.cif.gz"; this also matches the filter the inline version in
    # mmcif_parsing.py previously used.
    files = [f for f in os.listdir(args.mmcif_dir) if f.endswith(".cif")]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                # parse_file returns {} for unparseable files, so a plain
                # update() is always safe here.
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        # Requires `import json` at module level (previously missing,
        # which made this line raise NameError).
        fp.write(json.dumps(data))
if __name__ == "__main__":
    # Command-line entry point: two positional paths plus tunables for
    # the multiprocessing pool.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mmcif_dir",
        type=str,
        help="Directory containing mmCIF files",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path for .json output",
    )
    parser.add_argument(
        "--no_workers",
        type=int,
        default=4,
        help="Number of workers to use for parsing",
    )
    parser.add_argument(
        "--chunksize",
        type=int,
        default=10,
        help="How many files should be distributed to each worker at a time",
    )

    main(parser.parse_args())
train_openfold.py
View file @
23886619
...
@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
return
{
"loss"
:
loss
,
"pred"
:
outputs
[
"sm"
][
"positions"
][
-
1
].
detach
()
}
return
{
"loss"
:
loss
}
def validation_step(self, batch, batch_idx):
    """Run one validation step under the EMA weights and return the loss.

    On the first validation step of an epoch the current (training) weights
    are stashed in ``self.cached_weights`` and the EMA parameters are loaded
    into the model; ``validation_epoch_end`` is expected to restore them.
    """
    # At the start of validation, load the EMA weights
    if self.cached_weights is None:
        import copy

        # Bug fix 1: was `model.state_dict()` — `model` is an unresolved
        # name here; the live weights belong to `self.model`.
        # Bug fix 2: state_dict() returns references to the parameter
        # tensors, and load_state_dict() copies the EMA values into those
        # same tensors, which would clobber the cache — deep-copy it.
        self.cached_weights = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(self.ema.state_dict()["params"])

    # Calculate validation loss
    outputs = self(batch)
    # NOTE(review): assumes the trailing axis of each feature indexes
    # recycling iterations, keeping only the final one — confirm against
    # the data pipeline.
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss = self.loss(outputs, batch)
    return {"val_loss": loss}
#def validation_step(self, batch, batch_idx):
def validation_epoch_end(self, _):
    """Swap the original (non-EMA) weights back in after validation."""
    # validation_step cached the pre-EMA weights; restore and clear them.
    self.model.load_state_dict(self.cached_weights)
    self.cached_weights = None
def
configure_optimizers
(
self
,
def
configure_optimizers
(
self
,
learning_rate
:
float
=
1e-3
,
learning_rate
:
float
=
1e-3
,
...
@@ -140,7 +149,8 @@ if __name__ == "__main__":
...
@@ -140,7 +149,8 @@ if __name__ == "__main__":
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"Output of templates.generate_mmcif_cache"
help
=
"""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_small_bfd"
,
type
=
bool
,
default
=
False
,
"--use_small_bfd"
,
type
=
bool
,
default
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment