"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "16310b269f866e6f4b7968ba6780e54a4f7b76f6"
Commit 23886619 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add mmCIF cache generation script, remove verbose warnings

parent f649cccd
...@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -80,7 +80,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(template_release_dates_cache_path is None): if(template_release_dates_cache_path is None):
logging.warning( logging.warning(
"Template release dates cache does not exist. Remember to run " "Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_caches.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
...@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -358,6 +358,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
) )
else:
self.val_dataset = None
else: else:
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
...@@ -387,12 +389,15 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -387,12 +389,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
def val_dataloader(self): def val_dataloader(self):
return torch.utils.data.DataLoader( if(self.val_dataset is not None):
self.val_dataset, return torch.utils.data.DataLoader(
batch_size=self.config.data_module.data_loaders.batch_size, self.val_dataset,
num_workers=self.config.data_module.data_loaders.num_workers, batch_size=self.config.data_module.data_loaders.batch_size,
collate_fn=self._gen_batch_collator("eval") num_workers=self.config.data_module.data_loaders.num_workers,
) collate_fn=self._gen_batch_collator("eval")
)
return None
def predict_dataloader(self): def predict_dataloader(self):
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
......
...@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader: ...@@ -345,7 +345,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution = parsed_info[res_key][0] raw_resolution = parsed_info[res_key][0]
header["resolution"] = float(raw_resolution) header["resolution"] = float(raw_resolution)
except ValueError: except ValueError:
logging.warning( logging.info(
"Invalid resolution format: %s", parsed_info[res_key] "Invalid resolution format: %s", parsed_info[res_key]
) )
...@@ -475,27 +475,3 @@ def get_atom_coords( ...@@ -475,27 +475,3 @@ def get_atom_coords(
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
return all_atom_positions, all_atom_mask return all_atom_positions, all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Parse every ``.cif`` file in *mmcif_dir* and write a JSON cache to *out_path*.

    The cache maps each file ID (filename without extension) to a dict with
    the structure's ``release_date`` and its chain count (``no_chains``).
    Files that fail to parse are skipped with a warning.
    """
    cache = {}
    cif_files = [f for f in os.listdir(mmcif_dir) if f.endswith(".cif")]
    for filename in cif_files:
        with open(os.path.join(mmcif_dir, filename), "r") as fp:
            mmcif_string = fp.read()

        file_id = os.path.splitext(filename)[0]
        result = parse(file_id=file_id, mmcif_string=mmcif_string)
        if result.mmcif_object is None:
            # Best-effort: log and move on rather than aborting the whole run
            logging.warning(f"Could not parse {filename}. Skipping...")
            continue

        mmcif = result.mmcif_object
        cache[file_id] = {
            "release_date": mmcif.header["release_date"],
            "no_chains": len(list(mmcif.structure.get_chains())),
        }

    with open(out_path, "w") as fp:
        fp.write(json.dumps(cache))
import argparse
from functools import partial
import logging
from multiprocessing import Pool
import os
import sys
sys.path.append(".") # an innocent hack to get this to run from the top level
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
    """Parse a single mmCIF file and return ``{file_id: metadata}``.

    ``metadata`` holds the structure's ``release_date`` and chain count
    (``no_chains``). Returns an empty dict when the file cannot be parsed,
    so callers can merge results unconditionally. *args* must expose the
    ``mmcif_dir`` attribute.
    """
    path = os.path.join(args.mmcif_dir, f)
    with open(path, "r") as fp:
        mmcif_string = fp.read()

    file_id = os.path.splitext(f)[0]
    result = parse(file_id=file_id, mmcif_string=mmcif_string)
    if result.mmcif_object is None:
        # Unparseable file: report at info level and contribute nothing
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    mmcif = result.mmcif_object
    return {
        file_id: {
            "release_date": mmcif.header["release_date"],
            "no_chains": len(list(mmcif.structure.get_chains())),
        }
    }
def main(args):
    """Build a JSON metadata cache for a directory of mmCIF files in parallel.

    Fans the per-file parsing out over ``args.no_workers`` processes and
    writes a mapping of file ID -> {release_date, no_chains} to
    ``args.output_path``.
    """
    # Match the extension exactly: the previous `".cif" in f` test also
    # matched names that merely contain "cif" (e.g. "foo.cif.bak"),
    # and is inconsistent with the serial implementation's endswith check.
    files = [f for f in os.listdir(args.mmcif_dir) if f.endswith(".cif")]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # imap_unordered updates the progress bar as results arrive,
            # regardless of completion order; failed parses yield {} and
            # are no-ops for data.update
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data))
if __name__ == "__main__":
    # CLI entry point: declare the argument table, then hand off to main().
    cli = argparse.ArgumentParser()
    arg_table = [
        (["mmcif_dir"],
         dict(type=str, help="Directory containing mmCIF files")),
        (["output_path"],
         dict(type=str, help="Path for .json output")),
        (["--no_workers"],
         dict(type=int, default=4,
              help="Number of workers to use for parsing")),
        (["--chunksize"],
         dict(type=int, default=10,
              help="How many files should be distributed to each worker at a time")),
    ]
    for flags, options in arg_table:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())
...@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -43,15 +43,24 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss # Compute loss
loss = self.loss(outputs, batch) loss = self.loss(outputs, batch)
return {"loss": loss, "pred": outputs["sm"]["positions"][-1].detach()} return {"loss": loss}
def training_epoch_end(self, outs): def validation_step(self, batch, batch_idx):
out = outs[-1]["pred"].cpu() # At the start of validation, load the EMA weights
with open("prediction/preds_" + str(time.strftime("%H:%M:%S")) + ".pickle", "wb") as f: if(self.cached_weights is None):
pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL) self.cached_weights = model.state_dict()
self.model.load_state_dict(self.ema.state_dict()["params"])
# Calculate validation loss
outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = self.loss(outputs, batch)
return {"val_loss": loss}
#def validation_step(self, batch, batch_idx): def validation_epoch_end(self, _):
# outputs = self(batch) # Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
def configure_optimizers(self, def configure_optimizers(self,
learning_rate: float = 1e-3, learning_rate: float = 1e-3,
...@@ -140,7 +149,8 @@ if __name__ == "__main__": ...@@ -140,7 +149,8 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--template_release_dates_cache_path", type=str, default=None, "--template_release_dates_cache_path", type=str, default=None,
help="Output of templates.generate_mmcif_cache" help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
) )
parser.add_argument( parser.add_argument(
"--use_small_bfd", type=bool, default=False, "--use_small_bfd", type=bool, default=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment