Unverified Commit f434a278 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #439 from aqlaboratory/setup-improvements

Adds Documentation and minor quality of life fixes
parents 3eef7caa d8117ce3
This diff is collapsed.
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import logging
import argparse
import os
import shutil
import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys
from deepspeed.utils.zero_to_fp32 import (
get_optim_files, parse_optim_states, get_model_state_file
)
def convert_v1_to_v2_weights(args):
    """Convert an OpenFold v1 checkpoint into the v2 key-naming scheme.

    Supports two input layouts:
      * a DeepSpeed checkpoint directory (identified by a 'latest' tag file),
        in which case the whole tree is copied and the model state, EMA
        params, param shapes, and optimizer slice mappings are all rewritten;
      * a single PyTorch Lightning checkpoint file.

    Args:
        args: Namespace with ``input_ckpt_path`` (existing checkpoint file or
            directory) and ``output_ckpt_path`` (destination path).

    Raises:
        ValueError: If a DeepSpeed directory is given but contains no
            'latest' tag file.
    """
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint.
        # Bug fix: the original message was a plain string with an f-string
        # placeholder (no f-prefix) that also referenced a nonexistent
        # attribute (input_checkpoint_path); use lazy %-args instead.
        logging.info(
            'Converting deepspeed checkpoint found at %s', checkpoint_path)
        state_dict_key = 'module'
        latest_path = os.path.join(checkpoint_path, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A Pytorch Lightning checkpoint.
        # Bug fix: same missing f-prefix / wrong attribute as above.
        logging.info(
            'Converting pytorch lightning checkpoint found at %s',
            checkpoint_path)
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key])

    # The EMA weights (when present) carry the same deprecated key names.
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(
            ema_state_dict)

    if is_dir:
        # DeepSpeed also records parameter shapes keyed by the old names.
        param_shapes = convert_deprecated_v1_keys(
            model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]

        # Copy the full checkpoint tree first, then overwrite the converted
        # model/optimizer files inside the copy.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(
            model_output_path, os.path.basename(model_file))

        for optim_file in optim_files:
            # Bug fix: load optimizer shards onto CPU as well, so the
            # conversion works on machines without the GPUs the checkpoint
            # was saved on (the model file above is already CPU-mapped).
            optim_dict = torch.load(
                optim_file, map_location=torch.device('cpu'))
            new_optim_dict = optim_dict.copy()
            new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = convert_deprecated_v1_keys(
                optim_dict['optimizer_state_dict']['param_slice_mappings'][0])
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file))
            torch.save(new_optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
if __name__ == "__main__":
    # CLI entry point: two positional paths, input first, output second.
    cli_parser = argparse.ArgumentParser()
    for positional_arg in ("input_ckpt_path", "output_ckpt_path"):
        cli_parser.add_argument(positional_arg, type=str)
    convert_v1_to_v2_weights(cli_parser.parse_args())
"""
The OpenProteinSet alignment database is non-redundant, meaning that it only
stores one explicit representative alignment directory for all PDB chains in a
100% sequence identity cluster. In order to add explicit alignments for all PDB
chains, this script will add the missing chain directories and symlink them to
their representative alignment directories. This is required in order to train
OpenFold on the full PDB, not just one representative chain per cluster.
"""
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
    """
    Create duplicate directory symlinks for all chains in the given duplicate lists.

    Args:
        duplicate_chains (list[list[str]]): A list of lists, where each inner
            list contains chains that are 100% sequence identical.
        alignment_dir (Path): Path to flattened alignment directory, with one
            subdirectory per chain.
    """
    # Doc fix: the docstring previously described a nonexistent parameter
    # named ``duplicate_lists``; the actual parameter is ``duplicate_chains``.
    print("Creating duplicate directory symlinks...")
    dirs_created = 0
    for chains in tqdm(duplicate_chains):
        # Find the one chain in this cluster that has an explicit alignment
        # directory; it becomes the symlink target for all the others.
        for chain in chains:
            if (alignment_dir / chain).exists():
                representative_chain = chain
                break
        else:
            # for/else: reached only when no chain had an alignment dir.
            print(f"No representative chain found for {chains}, skipping...")
            continue

        # Create symlinks for all other chains in the cluster.
        for chain in chains:
            if chain != representative_chain:
                target_path = alignment_dir / chain
                if target_path.exists():
                    print(f"Chain {chain} already exists, skipping...")
                else:
                    (target_path).symlink_to(alignment_dir / representative_chain)
                    dirs_created += 1
    print(f"Created directories for {dirs_created} duplicate chains.")
def main(alignment_dir: Path, duplicate_chains_file: Path):
    """Load the duplicate-chain clusters and symlink their alignment dirs."""
    # Each line of the file is one cluster: whitespace-separated chain ids.
    with open(duplicate_chains_file, "r") as handle:
        clusters = [line.strip().split() for line in handle]

    # Symlink targets must be absolute, so resolve the root directory first.
    create_duplicate_dirs(clusters, alignment_dir.resolve())
if __name__ == "__main__":
    # Reuse the module docstring as the CLI description.
    cli = ArgumentParser(description=__doc__)
    cli.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to flattened alignment directory, with one subdirectory
        per chain.""",
    )
    cli.add_argument(
        "duplicate_chains_file",
        type=Path,
        help="""Path to file containing duplicate chains, where each line
        contains a space-separated list of chains that are 100%%
        sequence identical.
        """,
    )
    parsed = cli.parse_args()
    main(parsed.alignment_dir, parsed.duplicate_chains_file)
This diff is collapsed.
import argparse
import ctypes
from datetime import date
import os
import sys
from pathlib import Path
# Resolve sequence-tool binaries from the active conda environment when one
# is active; otherwise fall back to the system /bin directory.
CONDA_ENV_BINARY_PATH = Path(os.environ.get('CONDA_PREFIX', '/')) / 'bin'
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
......@@ -30,22 +36,22 @@ def add_data_args(parser: argparse.ArgumentParser):
'--bfd_database_path', type=str, default=None,
)
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
'--jackhmmer_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'jackhmmer'),
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
'--hhblits_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhblits'),
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
'--hhsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhsearch'),
)
parser.add_argument(
'--hmmsearch_binary_path', type=str, default='/usr/bin/hmmsearch'
'--hmmsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmsearch'),
)
parser.add_argument(
'--hmmbuild_binary_path', type=str, default='/usr/bin/hmmbuild'
'--hmmbuild_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmbuild'),
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
'--kalign_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'kalign'),
)
parser.add_argument(
'--max_template_date', type=str,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment