Unverified Commit 7d227395 authored by Jannik Gut's avatar Jannik Gut Committed by GitHub
Browse files

Merge branch 'main' into main

parents b38b6078 f37d0d96
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import pickle import pickle
import random import random
import time import time
import json
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -131,7 +132,16 @@ def generate_feature_dict( ...@@ -131,7 +132,16 @@ def generate_feature_dict(
args, args,
): ):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1:
if "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
elif len(seqs) == 1:
tag = tags[0] tag = tags[0]
seq = seqs[0] seq = seqs[0]
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
...@@ -143,14 +153,6 @@ def generate_feature_dict( ...@@ -143,14 +153,6 @@ def generate_feature_dict(
alignment_dir=local_alignment_dir, alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode, seqemb_mode=args.use_single_seq_mode,
) )
elif "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
else: else:
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write( fp.write(
...@@ -177,7 +179,21 @@ def main(args): ...@@ -177,7 +179,21 @@ def main(args):
if args.config_preset.startswith("seq"): if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) config = model_config(
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.trace_model: if args.trace_model:
if not config.data.predict.fixed_size: if not config.data.predict.fixed_size:
...@@ -261,6 +277,11 @@ def main(args): ...@@ -261,6 +277,11 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]]) seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn) sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {} feature_dicts = {}
if is_multimer and args.openfold_checkpoint_path:
raise ValueError(
'`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode')
model_generator = load_models_from_command_line( model_generator = load_models_from_command_line(
config, config,
args.model_device, args.model_device,
...@@ -459,6 +480,13 @@ if __name__ == "__main__": ...@@ -459,6 +480,13 @@ if __name__ == "__main__":
"--cif_output", action="store_true", default=False, "--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)" help="Output predicted models in ModelCIF format instead of PDB format (default)"
) )
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser.add_argument(
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""
import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from tqdm import tqdm
def chain_dir_to_fasta(dir: Path) -> str:
    """
    Generate a single FASTA entry for a chain directory.

    The chain's sequence is read from the query entry (first record) of one of
    the alignment files in the directory; the directory name is used as the
    chain ID.

    Args:
        dir: Chain directory containing .a3m alignment files.

    Returns:
        A FASTA-formatted string ">{chain_id}\\n{seq}\\n".

    Raises:
        FileNotFoundError: If none of the known alignment files exist in
            ``dir``.
    """
    # take some alignment file, preferring them in this order (the query
    # sequence is identical in each, so any one will do)
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        alignment_file = dir / alignment_file_type
        if alignment_file.exists():
            break
    else:
        # Previously this fell through and tried to open the last candidate,
        # yielding a confusing error message; fail explicitly instead.
        raise FileNotFoundError(f"No alignment file found in {dir}")

    with open(alignment_file, "r") as f:
        next(f)  # skip the first line (query header)
        seq = next(f).strip()
        try:
            next_line = next(f)
        except StopIteration:  # single-sequence alignment file
            pass
        else:
            assert next_line.startswith(">")  # ensure that sequence ended

    chain_id = dir.name
    return f">{chain_id}\n{seq}\n"
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
    """
    Generate a single FASTA entry from an alignment-db index entry.

    Args:
        index_entry: Index entry with a "db" shard filename and a "files"
            list of (file_name, start_offset, size) entries.
        db_dir: Directory containing the alignment-db shard files.
        chain_id: Chain ID used as the FASTA header.

    Returns:
        A FASTA-formatted string ">{chain_id}\\n{seq}\\n".

    Raises:
        FileNotFoundError: If the entry contains none of the known alignment
            files.
    """
    db_file = db_dir / index_entry["db"]

    # Find the byte span of the first available alignment file, preferring
    # them in this order. (The old version's inner-loop `break` did not stop
    # the outer loop, so a later file type silently overwrote start/size, and
    # an entry with no known file raised NameError.)
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        file_info = next(
            (fi for fi in index_entry["files"] if fi[0] == alignment_file_type),
            None,
        )
        if file_info is not None:
            start, size = file_info[1], file_info[2]
            break
    else:
        raise FileNotFoundError(
            f"No alignment file found in index entry for {chain_id}"
        )

    with open(db_file, "rb") as f:
        f.seek(start)
        msa_lines = f.read(size).decode("utf-8").splitlines()

    seq = msa_lines[1]
    try:
        next_line = msa_lines[2]
    except IndexError:  # single-sequence alignment
        pass
    else:
        assert next_line.startswith(">")  # ensure that sequence ended

    return f">{chain_id}\n{seq}\n"
def main(
    output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
    """
    Generate a FASTA file from either an alignment-db index or a chain
    directory, processing entries concurrently with a thread pool.
    """
    entries = []
    if alignment_dir and alignment_db_index:
        raise ValueError(
            "Only one of alignment_db_index and alignment_dir can be provided."
        )

    if alignment_dir:
        print("Creating FASTA from alignment directory...")
        chain_dirs = list(alignment_dir.iterdir())
        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(chain_dir_to_fasta, d) for d in chain_dirs]
            for done in tqdm(as_completed(pending), total=len(chain_dirs)):
                entries.append(done.result())
    elif alignment_db_index:
        print("Creating FASTA from alignment dbs...")
        with open(alignment_db_index, "r") as f:
            index = json.load(f)
        db_dir = alignment_db_index.parent
        with ThreadPoolExecutor() as pool:
            pending = [
                pool.submit(index_entry_to_fasta, entry, db_dir, chain)
                for chain, entry in index.items()
            ]
            for done in tqdm(as_completed(pending), total=len(index)):
                entries.append(done.result())
    else:
        raise ValueError("Either alignment_db_index or alignment_dir must be provided.")

    with open(output_path, "w") as f:
        f.write("".join(entries))
    print(f"FASTA file written to {output_path}.")
if __name__ == "__main__":
    # CLI: positional output path plus exactly one of the two mutually
    # exclusive alignment sources (mutual exclusion is validated in main()).
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "output_path",
        type=Path,
        help="Path to output FASTA file.",
    )
    parser.add_argument(
        "--alignment_db_index",
        type=Path,
        help="Path to alignment-db index file.",
    )
    parser.add_argument(
        "--alignment_dir",
        type=Path,
        help="Path to alignment directory.",
    )
    args = parser.parse_args()
    main(args.output_path, args.alignment_db_index, args.alignment_dir)
from argparse import ArgumentParser
from pathlib import Path
import json
def main(args):
    """
    Extend an alignment-db super index with entries for duplicate chains.

    Chains that share a sequence with an already-indexed chain are pointed at
    that chain's existing alignment entry; chains for which no indexed
    sequence-duplicate exists are reported at the end.
    """
    # load the existing super index
    with open(args.alignment_db_super_index_path, "r") as fp:
        super_index = json.load(fp)

    # parse the two-lines-per-record FASTA into {chain: sequence}
    with open(args.all_chains_fasta, "r") as fp:
        lines = fp.readlines()
    chains_to_seqs = {
        lines[i][1:].strip(): lines[i + 1].strip() for i in range(0, len(lines), 2)
    }

    chains_w_alignments = set(super_index.keys())
    chains_wo_alignments = set(chains_to_seqs.keys()) - chains_w_alignments

    # one representative indexed chain per unique sequence
    seq_to_chain_w_alignment = {
        chains_to_seqs[chain]: chain for chain in chains_w_alignments
    }
    print("Unique sequences with alignments:", len(seq_to_chain_w_alignment))

    # map each chain without alignment to the alignment entry of a
    # sequence-identical chain that has one
    remaining_unaligned_chains = []
    for chain in chains_wo_alignments:
        seq = chains_to_seqs[chain]
        representative = seq_to_chain_w_alignment.get(seq)
        if representative is None:
            # no corresponding chain with alignment found
            remaining_unaligned_chains.append(chain)
        else:
            super_index[chain] = super_index[representative]

    with open(args.output_path, "w") as fp:
        json.dump(super_index, fp)

    print(
        f"No corresponding alignment found for the following {len(remaining_unaligned_chains)} chains:",
        remaining_unaligned_chains,
    )
if __name__ == "__main__":
    # CLI: input super index, output path, and the all-chains FASTA used to
    # discover sequence-duplicate chains.
    parser = ArgumentParser(
        description="""
        If the alignment-db index was created on unique-chain alignments only,
        this will add the missing chain entries to the super-index file based on
        a .fasta file that contains sequences for all chains.
        Note that this only modifies the index and not the database itself, as
        the duplicate sequences will just point to the same alignments.
        """
    )
    parser.add_argument(
        "alignment_db_super_index_path",
        type=Path,
        help="Path to alignment-db super index file.",
    )
    parser.add_argument(
        "output_path", type=Path, help="Write the output super index to this path."
    )
    parser.add_argument(
        "all_chains_fasta",
        type=Path,
        help="Path to the fasta file containing sequences for all chains.",
    )
    args = parser.parse_args()
    main(args)
"""
This is a modified version of the create_alignment_db.py script in OpenFold
which supports sharding into multiple files. The created index is already a
super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""
import argparse
import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from math import ceil
from multiprocessing import cpu_count
from pathlib import Path
from tqdm import tqdm
def split_file_list(file_list: list[Path], n_shards: int):
    """
    Partition file_list into n_shards round-robin sublists: shard i receives
    every n_shards-th file starting at index i.
    """
    shards = [file_list[shard::n_shards] for shard in range(n_shards)]
    # sanity check: no file lost or duplicated by the split
    assert sum(len(shard) for shard in shards) == len(file_list)
    return shards
def chunked_iterator(lst: list, chunk_size: int):
    """Yield successive chunk_size-sized slices of lst (last may be shorter)."""
    start = 0
    while start < len(lst):
        yield lst[start : start + chunk_size]
        start += chunk_size
def read_chain_dir(chain_dir: Path) -> dict:
    """
    Read every alignment file in a single chain directory.

    Returns a single-entry dict mapping the chain name (with the PDB-ID part
    lowercased) to a list of (file_name, file_bytes) tuples, sorted by file
    name.
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    # normalize the PDB ID to lowercase while keeping the chain ID as-is
    pdb_id, chain = chain_dir.name.split("_")
    chain_name = f"{pdb_id.lower()}_{chain}"

    file_data = [
        (path.name, path.read_bytes()) for path in sorted(chain_dir.iterdir())
    ]
    return {chain_name: file_data}
def process_chunk(chain_files: list[Path]) -> dict:
    """
    Read all chain directories in a chunk concurrently and merge the results
    into one {chain_name: [(file_name, file_bytes), ...]} dict.
    """
    merged: dict = {}
    with ThreadPoolExecutor() as pool:
        for chain_entry in pool.map(read_chain_dir, chain_files):
            merged.update(chain_entry)
    return merged
def create_index_default_dict() -> dict:
    """Return a fresh, empty index entry (factory for a defaultdict)."""
    return dict(db=None, files=[])
def create_shard(
    shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Create a single shard of the alignment database.

    Concatenates all alignment files of the given chain directories into
    "{output_name}_{shard_num}.db" in output_dir and returns this shard's
    index, mapping each chain name to the shard file name and the
    (file_name, byte_offset, byte_length) triple of every stored file.
    """
    CHUNK_SIZE = 200  # chain directories read per batch

    # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
    shard_index = defaultdict(create_index_default_dict)

    chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
    pbar_desc = f"Shard {shard_num}"
    output_path = output_dir / f"{output_name}_{shard_num}.db"

    db_offset = 0
    # `with` guarantees the shard file is closed even if a read fails
    # (the original open()/close() pair leaked the handle on exceptions)
    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunk_iter,
            total=ceil(len(shard_files) / CHUNK_SIZE),
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # get processed files for one chunk
            chunk_data = process_chunk(files_chunk)

            # write to db and record (offset, length) in the index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name
                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    db_offset += file_length

    return shard_index
def main(args):
    """
    Build a sharded alignment database from a flattened alignment directory.

    Chain directories are split round-robin across n_shards shard files, each
    built in its own worker process; the per-shard indices are merged into a
    single super index (optionally extended with duplicate chains) and
    written next to the shard files.
    """
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_dir.mkdir(exist_ok=True, parents=True)
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    # one worker process per shard, so more shards than cores oversubscribes
    n_cpus = cpu_count()
    if n_shards > n_cpus:
        print(
            f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
            "This may result in slower performance. Consider using a smaller number of shards."
        )

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])

    # split chain dirs into n_shards sublists
    chain_dir_shards = split_file_list(all_chain_dirs, n_shards)

    # total index for all shards
    super_index = {}

    # create a shard for each sublist
    print(f"Creating {n_shards} alignment-db files...")
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_index
            )
            for shard_index, shard_files in enumerate(chain_dir_shards)
        ]
        # merge each shard's index into the super index as it completes
        for future in as_completed(futures):
            shard_index = future.result()
            super_index.update(shard_index)
    print("\nCreated all shards.")

    if args.duplicate_chains_file:
        # duplicate chains reuse their representative's entry, so only the
        # index grows — the database files are untouched
        print("Extending super index with duplicate chains...")
        duplicates_added = 0
        with open(args.duplicate_chains_file, "r") as fp:
            duplicate_chains = [line.strip().split() for line in fp]

        for chains in duplicate_chains:
            # find representative with alignment
            for chain in chains:
                if chain in super_index:
                    representative_chain = chain
                    break
            else:
                print(f"No representative chain found for {chains}, skipping...")
                continue

            # add duplicates to index
            for chain in chains:
                if chain != representative_chain:
                    super_index[chain] = super_index[representative_chain]
                    duplicates_added += 1

        print(f"Added {duplicates_added} duplicate chains to index.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)
    print("Done.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        This script creates an alignment database format from a directory of
        precomputed alignments. For better file system health, the total
        database is split into n_shards files, where each shard contains a
        subset of the total alignments. The output is a directory containing the
        n_shards database files, and a single index file mapping chain names to
        the database file and byte offsets for each alignment file.
        Note: For optimal performance, your machine should have at least as many
        cores as shards you want to create.
        """
    )
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to precomputed flattened alignment directory, with one
        subdirectory per chain.""",
    )
    # shard files and index are written under output_db_path with this name stem
    parser.add_argument("output_db_path", type=Path)
    parser.add_argument("output_db_name", type=str)
    parser.add_argument(
        "--n_shards",
        type=int,
        help="Number of shards to split the database into",
        default=10,
    )
    parser.add_argument(
        "--duplicate_chains_file",
        type=Path,
        help="""
        Optional path to file containing duplicate chain information, where each
        line contains chains that are 100% sequence identical. If provided,
        duplicate chains will be added to the index and point to the same
        underlying database entry as their representatives in the alignment dir.
        """,
        default=None,
    )
    args = parser.parse_args()
    main(args)
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import logging
import argparse
import os
import shutil
import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys
from deepspeed.utils.zero_to_fp32 import (
get_optim_files, parse_optim_states, get_model_state_file
)
def convert_v1_to_v2_weights(args):
    """
    Convert a v1 OpenFold checkpoint (a DeepSpeed checkpoint directory or a
    single PyTorch Lightning file) to v2 key naming, writing the result to
    args.output_ckpt_path.
    """
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint.
        # (The original message was a plain string containing
        # '{args.input_checkpoint_path}' — missing f-prefix AND a wrong
        # attribute name; use lazy %-style logging with the real path.)
        logging.info(
            'Converting deepspeed checkpoint found at %s', checkpoint_path)
        state_dict_key = 'module'
        latest_path = os.path.join(checkpoint_path, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A Pytorch Lightning checkpoint (same f-string/attribute fix as above)
        logging.info(
            'Converting pytorch lightning checkpoint found at %s',
            checkpoint_path)
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key])

    # the EMA copy of the weights must be renamed consistently
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(
            ema_state_dict)

    if is_dir:
        # DeepSpeed also stores param shapes and per-rank optimizer slice
        # mappings keyed by parameter name; rename those too, then copy the
        # remaining checkpoint files over unchanged.
        param_shapes = convert_deprecated_v1_keys(
            model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(
            model_output_path, os.path.basename(model_file))
        for optim_file in optim_files:
            # NOTE(review): loads to the tensors' original device; consider
            # map_location='cpu' for portability — confirm
            optim_dict = torch.load(optim_file)
            new_optim_dict = optim_dict.copy()
            new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = convert_deprecated_v1_keys(
                optim_dict['optimizer_state_dict']['param_slice_mappings'][0])
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file))
            torch.save(new_optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
if __name__ == "__main__":
    # two positional paths: the v1 checkpoint to read, and where to write v2
    parser = argparse.ArgumentParser()
    parser.add_argument("input_ckpt_path", type=str)
    parser.add_argument("output_ckpt_path", type=str)
    args = parser.parse_args()
    convert_v1_to_v2_weights(args)
#!/bin/bash #!/bin/bash
# #
# Copyright 2021 DeepMind Technologies Limited # Copyright 2024 AlQuraishi Laboratory
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
# Downloads OpenFold parameters. # Downloads OpenFold SoloSeq (single sequence model) parameters.
# #
# Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory # Usage: bash download_openfold_soloseq_params.sh /path/to/download/directory
set -e set -e
if [[ $# -eq 0 ]]; then if [[ $# -eq 0 ]]; then
......
#!/bin/bash
#
# Copyright 2024 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads ESM-1b embeddings used to train OpenFold SoloSeq single-seq model.
#
# Usage: bash download_soloseq_embeddings.sh /path/to/download/directory
set -e

# the download directory is required as the single positional argument
if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

# the embeddings are fetched with the AWS CLI; fail early if it is missing
if ! command -v aws &> /dev/null ; then
    echo "Error: aws could not be found. Please install aws."
    exit 1
fi

DOWNLOAD_DIR="${1}/soloseq_embeddings"
mkdir -p "${DOWNLOAD_DIR}"
# public OpenFold bucket; --no-sign-request avoids needing AWS credentials
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/soloseq_embeddings/ "${DOWNLOAD_DIR}" --recursive
"""
The OpenProteinSet alignment database is non-redundant, meaning that it only
stores one explicit representative alignment directory for all PDB chains in a
100% sequence identity cluster. In order to add explicit alignments for all PDB
chains, this script will add the missing chain directories and symlink them to
their representative alignment directories. This is required in order to train
OpenFold on the full PDB, not just one representative chain per cluster.
"""
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
    """
    Symlink every duplicate chain directory to its cluster representative.

    Args:
        duplicate_chains (list[list[str]]): A list of lists, where each inner
            list contains chains that are 100% sequence identical.
        alignment_dir (Path): Path to flattened alignment directory, with one
            subdirectory per chain.
    """
    print("Creating duplicate directory symlinks...")
    dirs_created = 0
    for chains in tqdm(duplicate_chains):
        # the representative is the first chain that already has alignments
        representative_chain = next(
            (c for c in chains if (alignment_dir / c).exists()), None
        )
        if representative_chain is None:
            print(f"No representative chain found for {chains}, skipping...")
            continue

        # link every remaining chain of the cluster to the representative
        for chain in chains:
            if chain == representative_chain:
                continue
            link_path = alignment_dir / chain
            if link_path.exists():
                print(f"Chain {chain} already exists, skipping...")
            else:
                link_path.symlink_to(alignment_dir / representative_chain)
                dirs_created += 1

    print(f"Created directories for {dirs_created} duplicate chains.")
def main(alignment_dir: Path, duplicate_chains_file: Path):
    """Read the duplicate-chains file and create the corresponding symlinks."""
    # read duplicate chains file: one space-separated cluster per line
    with open(duplicate_chains_file, "r") as fp:
        duplicate_chains = [line.strip().split() for line in fp]

    # symlink targets must be absolute, so resolve the alignment dir first
    create_duplicate_dirs(duplicate_chains, alignment_dir.resolve())
if __name__ == "__main__":
    # CLI: the flattened alignment directory and the duplicate-chains file
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to flattened alignment directory, with one subdirectory
        per chain.""",
    )
    parser.add_argument(
        "duplicate_chains_file",
        type=Path,
        help="""Path to file containing duplicate chains, where each line
        contains a space-separated list of chains that are 100%%
        sequence identical.
        """,
    )
    args = parser.parse_args()
    main(args.alignment_dir, args.duplicate_chains_file)
"""
This script takes a .fasta file as input and then clusters it on a given
sequence identity threshold using mmseqs2. The mmseqs2 flags are identical to
what PDB officially uses to provide their official sequence clusters
(https://github.com/soedinglab/MMseqs2/issues/452).
"""
import shutil
import subprocess
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
def reformat_cluster_file(cluster_file: Path, output_file: Path):
    """
    Convert an mmseqs2 cluster .tsv into one-cluster-per-line text.

    The input has one "cluster_representative chain_id" pair per line; the
    output has all {PDB_ID}_{CHAIN_ID} members of a cluster space-separated
    on a single line.
    """
    cluster_to_chains = defaultdict(list)

    # group chain IDs by their cluster representative
    with open(cluster_file, "r") as f:
        for raw_line in f:
            cluster_name, chain_id = raw_line.strip().split()
            cluster_to_chains[cluster_name].append(chain_id)

    # emit one cluster per output line
    with open(output_file, "w") as f:
        f.writelines(
            f"{' '.join(members)}\n" for members in cluster_to_chains.values()
        )
def main(args):
    """
    Cluster the input FASTA with mmseqs2 (PDB settings), write the
    reformatted cluster file to args.output_file, and clean up all
    intermediate mmseqs output.
    """
    input_file = args.input_fasta.absolute()
    output_file = args.output_file.absolute()
    output_dir = args.output_file.parent

    # prefix that all output files get
    mmseqs_prefix = "_mmseqs_out"
    # temporary directory that mmseqs2 uses
    tmp_name = f"{mmseqs_prefix}_temp"
    tmp_dir = output_dir / tmp_name

    # flags mirror PDB's official clustering settings
    # (https://github.com/soedinglab/MMseqs2/issues/452)
    mmseqs_command = [
        args.mmseqs_binary_path,
        "easy-cluster",
        input_file,
        mmseqs_prefix,
        tmp_name,
        "--min-seq-id",
        str(args.seq_id),
        "-c",
        "0.9",
        "-s",
        "8",
        "--max-seqs",
        "1000",
        "--cluster-mode",
        "1",
    ]

    # run mmseqs with PDB settings
    print("Running mmseqs2...")
    subprocess.run(mmseqs_command, check=True, cwd=output_dir)

    # derive the cluster-file name from the prefix instead of hard-coding
    # "_mmseqs_out_cluster.tsv" (the old literal silently broke if the
    # prefix changed)
    cluster_file = output_dir / f"{mmseqs_prefix}_cluster.tsv"

    print("Reformatting output file...")
    reformat_cluster_file(cluster_file, output_file)

    print("Cleaning up mmseqs2 output...")
    mmseqs_outputs = [
        output_dir / f"{mmseqs_prefix}_{suffix}"
        for suffix in ["cluster.tsv", "rep_seq.fasta", "all_seqs.fasta"]
    ]
    for file in mmseqs_outputs:
        file.unlink()
    shutil.rmtree(tmp_dir)
    print("Done!")
if __name__ == "__main__":
    # CLI mirrors the module docstring: cluster a FASTA at a given identity
    parser = ArgumentParser(
        description=__doc__
    )
    parser.add_argument(
        "input_fasta",
        type=Path,
        help="Input .fasta file. Sequence names should be in format >{PDB_ID}_{CHAIN_ID}",
    )
    parser.add_argument(
        "output_file",
        type=Path,
        help="Output file. Each line will contain a space-separated list of {PDB_ID}_{CHAIN_ID} belonging to the same cluster.",
    )
    parser.add_argument("mmseqs_binary_path", type=str, help="Path to mmseqs binary")
    parser.add_argument("--seq-id", type=float, default=0.4, help="Sequence identity threshold for clustering.")
    args = parser.parse_args()
    main(args)
import argparse import argparse
import ctypes import ctypes
from datetime import date from datetime import date
import os
import sys import sys
from pathlib import Path
if 'CONDA_PREFIX' in os.environ:
CONDA_ENV_BINARY_PATH= Path(os.environ['CONDA_PREFIX']) / 'bin'
else:
CONDA_ENV_BINARY_PATH = Path('/bin')
def add_data_args(parser: argparse.ArgumentParser): def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
...@@ -30,22 +36,22 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -30,22 +36,22 @@ def add_data_args(parser: argparse.ArgumentParser):
'--bfd_database_path', type=str, default=None, '--bfd_database_path', type=str, default=None,
) )
parser.add_argument( parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer' '--jackhmmer_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'jackhmmer'),
) )
parser.add_argument( parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits' '--hhblits_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhblits'),
) )
parser.add_argument( parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch' '--hhsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhsearch'),
) )
parser.add_argument( parser.add_argument(
'--hmmsearch_binary_path', type=str, default='/usr/bin/hmmsearch' '--hmmsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmsearch'),
) )
parser.add_argument( parser.add_argument(
'--hmmbuild_binary_path', type=str, default='/usr/bin/hmmbuild' '--hmmbuild_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmbuild'),
) )
parser.add_argument( parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign' '--kalign_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'kalign'),
) )
parser.add_argument( parser.add_argument(
'--max_template_date', type=str, '--max_template_date', type=str,
......
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import math
import os
from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
import deepspeed
from deepspeed.utils import logger
# flip to a truthy value for verbose diagnostic prints below
debug = 0

# load to cpu
device = torch.device('cpu')
def get_model_state_file(checkpoint_dir, zero_stage):
    """
    Return the path of the (single) model-states file in a DeepSpeed
    checkpoint directory for the given ZeRO stage.

    Raises:
        FileNotFoundError: If the directory or the expected file is missing.
        ValueError: If zero_stage is not 2 or 3.
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage == 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        # previously fell through here with `file` unbound -> NameError
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
def get_optim_files(checkpoint_dir):
    """
    Return the sorted list of '*_optim_states.pt' files in checkpoint_dir.

    Raises FileNotFoundError if no such files are present.
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    pattern = os.path.join(checkpoint_dir, "*_optim_states.pt")
    optim_files = sorted(glob.glob(pattern))
    if not optim_files:
        raise FileNotFoundError(
            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
    return optim_files
def parse_model_state(file):
    """
    Load a DeepSpeed model-states file and return its persistent buffers,
    keyed by name, upcast to fp32 where they were saved in fp16.
    """
    state_dict = torch.load(file, map_location=device)

    if "buffer_names" not in state_dict:
        raise ValueError(f"{file} is not a model state checkpoint")
    buffer_names = state_dict["buffer_names"]
    if debug:
        print("Found buffers:", buffer_names)

    # recover just the buffers while restoring them to fp32 if they were saved in fp16
    module_state = state_dict["module"]
    return {
        name: tensor.float()
        for name, tensor in module_state.items()
        if name in buffer_names
    }
def parse_optim_states(files, ds_checkpoint_dir):
    """Load all per-rank optimizer state files and extract what reconstruction
    needs.

    Args:
        files: sorted list of '*_optim_states.pt' paths, one per rank.
        ds_checkpoint_dir: the checkpoint tag directory (used in messages).

    Returns:
        Tuple ``(zero_stage, world_size, param_shapes, fp32_flat_groups)``
        where ``fp32_flat_groups`` has one entry per rank. For ZeRO-3 each
        rank's per-group flat tensors are concatenated into a single tensor.

    Raises:
        ValueError: if the files are not a zero checkpoint, the file count
            doesn't match the dp world size, or the zero stage is unknown.
    """
    num_files = len(files)
    state_dicts = [torch.load(path, map_location=device) for path in files]

    first_osd = state_dicts[0]['optimizer_state_dict']
    if "zero_stage" not in first_osd:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_osd["zero_stage"]
    world_size = first_osd["partition_count"]
    param_shapes = state_dicts[0]["param_shapes"]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != num_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {num_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_flat_groups = [
            sd['optimizer_state_dict']["single_partition_of_fp32_groups"]
            for sd in state_dicts
        ]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor
        fp32_flat_groups = [
            torch.cat(sd['optimizer_state_dict']["fp32_flat_groups"], 0)
            for sd in state_dicts
        ]
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    return zero_stage, world_size, param_shapes, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(
        optim_files, ds_checkpoint_dir)
    print(
        f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    buffers = parse_model_state(get_model_state_file(ds_checkpoint_dir, zero_stage))

    # parse_optim_states() has already rejected any stage other than 2 or 3.
    reconstruct = {
        2: _get_fp32_state_dict_from_zero2_checkpoint,
        3: _get_fp32_state_dict_from_zero3_checkpoint,
    }[zero_stage]
    return reconstruct(world_size, param_shapes, fp32_flat_groups, buffers)
def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
                                               param_shapes,
                                               fp32_flat_groups,
                                               buffers):
    """Reconstruct a consolidated fp32 state_dict from ZeRO-2 partitions.

    Args:
        world_size: data-parallel world size (number of ranks / optim files).
        param_shapes: list (one per param group) of dicts mapping parameter
            name -> shape, in flattening order.
        fp32_flat_groups: per-rank lists of flat fp32 partitions, one tensor
            per param group.
        buffers: dict of model buffers (already fp32), copied into the result
            as-is.

    Returns:
        OrderedDict mapping buffer and parameter names to fp32 tensors.

    Raises:
        ValueError: if the consumed and available element counts disagree
            after nccl alignment (see comment below).
    """
    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)

    # Concatenate each rank's partition of a param group into one full flat
    # vector per group.
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for i in range(num_param_groups):
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum([
        full_single_fp32_vector.numel()
        for full_single_fp32_vector in merged_single_partition_of_fp32_groups
    ])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum(
            [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    state_dict = OrderedDict()

    # buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():

            unpartitioned_numel = shape.numel()
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(
                    f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
                )
            # each param is a contiguous slice of the group's flat vector
            state_dict[name] = full_single_fp32_vector.narrow(
                0,
                offset,
                unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(
                f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
    )

    return state_dict
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return the per-rank partition size and trailing padding for a param.

    ``partitioned_numel`` is the number of elements each rank holds
    (ceil division), and ``padding_numel`` is how many filler elements pad
    the last partition when the param doesn't divide evenly.
    """
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    leftover = unpartitioned_numel % world_size
    padding_numel = world_size - leftover if leftover else 0
    return partitioned_numel, padding_numel
def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
                                               param_shapes,
                                               fp32_flat_groups,
                                               buffers):
    """Reconstruct a consolidated fp32 state_dict from ZeRO-3 partitions.

    Args:
        world_size: data-parallel world size (number of ranks / optim files).
        param_shapes: list of dicts mapping parameter name -> shape; merged
            below into a single dict preserving order.
        fp32_flat_groups: one flat fp32 tensor per rank (the per-group
            tensors were already concatenated by parse_optim_states).
        buffers: dict of model buffers (already fp32), copied into the result
            as-is.

    Returns:
        OrderedDict mapping buffer and parameter names to fp32 tensors.

    Raises:
        ValueError: if consumed and available element counts disagree.
    """
    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    avail_numel = fp32_flat_groups[0].numel() * world_size
    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    print(f"Have {avail_numel} numels to process.")
    print(f"Need {wanted_numel} numels in {wanted_params} params.")

    state_dict = OrderedDict()

    # buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in param_shapes.items():

        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        # Gather this param's slice from every rank's flat tensor, then trim
        # the trailing padding and restore the original shape.
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0,
                                             offset,
                                             partitioned_numel)
                  for i in range(world_size)),
            0).narrow(0,
                      0,
                      unpartitioned_numel).view(shape)
        offset += partitioned_numel

    # offset advanced by per-rank partition sizes; scale by world_size to
    # compare against the total available elements.
    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(
            f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
    )

    return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will
          attempt to load tag in 'latest' file. e.g., ``global_step14``

    Returns:
        - pytorch ``state_dict``

    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
    """
    if tag is None:
        # Resolve the tag from the 'latest' file DeepSpeed writes at save time.
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
    """
    consolidated = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
    print(f"Saving fp32 state dict to {output_file}")
    torch.save(consolidated, output_file)
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
    """
    logger.info(f"Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info(f"Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: the consolidated state_dict covers only checkpointed
    # params/buffers; any other entries in the model keep their current values.
    model.load_state_dict(state_dict, strict=False)

    return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """Return the global step encoded in the checkpoint's 'latest' tag.

    Reads the 'latest' file in ``checkpoint_dir`` and parses the step number
    from a tag of the form ``global_step<N>``.

    Args:
        checkpoint_dir: checkpoint folder containing a 'latest' file.

    Returns:
        The parsed step as int, or -1 if the tag doesn't follow the
        ``global_step<N>`` naming scheme. (Previously a non-matching tag
        crashed with AttributeError because ``re.match`` returned None.)

    Raises:
        ValueError: if no 'latest' file exists in ``checkpoint_dir``.
    """
    global_step = -1
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()

    match = re.match(r"global_step([0-9]+)", tag)
    if match is not None:
        global_step = int(match.group(1))
    return global_step
if __name__ == "__main__":
    # CLI entry point: convert a DeepSpeed ZeRO checkpoint folder into one
    # consolidated fp32 state_dict file loadable with plain torch.load().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help=
        "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
    )
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # overrides the module-level `debug` flag read by the helpers above
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
...@@ -48,7 +48,8 @@ class TestPermutation(unittest.TestCase): ...@@ -48,7 +48,8 @@ class TestPermutation(unittest.TestCase):
self.chain_a_num_res = 9 self.chain_a_num_res = 9
self.chain_b_num_res = 13 self.chain_b_num_res = 13
# below create default fake ground truth structures for a hetero-pentamer A2B3 # below create default fake ground truth structures for a hetero-pentamer A2B3
self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3 self.residue_index = list(
range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3 self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res + [ self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res + [
3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res], device=device) 3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res], device=device)
...@@ -63,19 +64,44 @@ class TestPermutation(unittest.TestCase): ...@@ -63,19 +64,44 @@ class TestPermutation(unittest.TestCase):
'entity_id': self.entity_id, 'entity_id': self.entity_id,
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(
batch, batch['asym_id'])
anchor_gt_asym = int(anchor_gt_asym) anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym} anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2} expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5} expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors) self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors) self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set # Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym) self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors) self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self): def test_2_permutation_pentamer(self):
"""
Test the permutation results on a pentamer A2B3, in which protein A has 9 residues
and protein B has 13 residues.
Expected outputs:
Only protein A should be selected as an anchor thus, in the output list, either [(0,1), (1,0)] or [(0,0), (1,1)] are allowed
The 3 chains from protein B should ALWAYS be aligned in a way that predicted b1 to be aligned with ground truth b1, pred b2 to ground truth b2
as shown below:
predicted structure: a2 - a1 - b2 - b3 - b1
indexes in the predicted list: 0 1 2 3 4
ground truth structure: a1 - a2 - b1 - b2 - b3
indexes in the ground truth list: 0 1 2 3 4
then the 2 protein A chains are free to be aligned by either order, thus either [(0,1),(1,0)] or [(0,0),(1,1)] is valid.
However, the 3 protein B chains should be strictly aligned in the following order:
[(2,3), (3,4), (4,2)], regardless of how protein A chains are aligned.
Therefore, the only 2 correct permutations are :
[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and
[(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
"""
batch = { batch = {
'asym_id': self.asym_id, 'asym_id': self.asym_id,
'sym_id': self.sym_id, 'sym_id': self.sym_id,
...@@ -85,7 +111,7 @@ class TestPermutation(unittest.TestCase): ...@@ -85,7 +111,7 @@ class TestPermutation(unittest.TestCase):
} }
batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res) batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res)
batch["residue_index"] = torch.tensor([self.residue_index]) batch["residue_index"] = torch.tensor([self.residue_index])
# create fake ground truth atom positions # create fake ground truth atom positions
chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37), chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3) dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10 chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
...@@ -93,16 +119,22 @@ class TestPermutation(unittest.TestCase): ...@@ -93,16 +119,22 @@ class TestPermutation(unittest.TestCase):
chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37), chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3) dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10 chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 chain_b3_pos = torch.matmul(torch.matmul(
# Below permutate predicted chain positions chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) # Below permutate predicted chain positions
# here the b2 chain from the ground truth is deliberately put in b1 chain's position, and predicted b3 chain to b2's position
# and predicted b1 chain to b3's position
pred_atom_position = torch.cat(
(chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
pred_atom_mask = torch.ones((1, self.num_res, 37)) pred_atom_mask = torch.ones((1, self.num_res, 37))
out = { out = {
'final_atom_positions': pred_atom_position, 'final_atom_positions': pred_atom_position,
'final_atom_mask': pred_atom_mask 'final_atom_mask': pred_atom_mask
} }
true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) true_atom_position = torch.cat(
(chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)), true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_a_num_res, 37)),
torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)),
...@@ -111,27 +143,47 @@ class TestPermutation(unittest.TestCase): ...@@ -111,27 +143,47 @@ class TestPermutation(unittest.TestCase):
batch['all_atom_positions'] = true_atom_position batch['all_atom_positions'] = true_atom_position
batch['all_atom_mask'] = true_atom_mask batch['all_atom_mask'] = true_atom_mask
aligns, _ = compute_permutation_alignment(out, batch, aligns, per_asym_residue_index = compute_permutation_alignment(out, batch,
batch) batch)
print(f"##### aligns is {aligns}")
possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]] expected_asym_residue_index = {
wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]] 1: torch.tensor(list(range(self.chain_a_num_res))),
self.assertIn(aligns, possible_outcome) 2: torch.tensor(list(range(self.chain_a_num_res))),
self.assertNotIn(aligns, wrong_outcome) 3: torch.tensor(list(range(self.chain_b_num_res))),
4: torch.tensor(list(range(self.chain_b_num_res))),
5: torch.tensor(list(range(self.chain_b_num_res)))
}
chain_a_permutated_chain_b_permutated = [
(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)]
chain_a_not_permutated_chain_b_permutated = [
(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
chain_a_permutated_chain_b_not_permuated = [
(0, 1), (1, 0), (2, 2), (3, 3), (4, 4)]
chain_a_not_permutated_chain_b_not_permuated = [
(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
# test on the permutation alignments
self.assertIn(aligns, [chain_a_permutated_chain_b_permutated,
chain_a_not_permutated_chain_b_permutated])
self.assertNotIn(aligns, [chain_a_permutated_chain_b_not_permuated,
chain_a_not_permutated_chain_b_not_permuated])
# test on the per_aysm_residue_index
for k, v in expected_asym_residue_index.items():
self.assertTrue(torch.equal(v, per_asym_residue_index[k]))
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self): def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325 nres_pad = 325 - 57 # suppose the cropping size is 325
batch = { batch = {
'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1), 'asym_id': self.asym_id,
'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1), 'sym_id': self.sym_id,
'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1), 'entity_id': self.entity_id,
'aatype': torch.randint(21, size=(1, 325)), 'aatype': torch.randint(21, size=(1, 57)),
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
batch['asym_id'] = batch['asym_id'].reshape(1, 325) batch['asym_id'] = batch['asym_id'].reshape(1, 57)
batch["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1) batch["residue_index"] = torch.tensor([self.residue_index])
# create fake ground truth atom positions # create fake ground truth atom positions
chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37), chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3) dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10 chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
...@@ -139,39 +191,64 @@ class TestPermutation(unittest.TestCase): ...@@ -139,39 +191,64 @@ class TestPermutation(unittest.TestCase):
chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37), chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3) dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10 chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 chain_b3_pos = torch.matmul(torch.matmul(
# Below permutate predicted chain positions chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) # Below permutate predicted chain positions
pred_atom_position = torch.cat(
(chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
pred_atom_mask = torch.ones((1, self.num_res, 37)) pred_atom_mask = torch.ones((1, self.num_res, 37))
pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1) pred_atom_position = pad_features(
pred_atom_position, nres_pad, pad_dim=1)
pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1) pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1)
out = { out = {
'final_atom_positions': pred_atom_position, 'final_atom_positions': pred_atom_position,
'final_atom_mask': pred_atom_mask 'final_atom_mask': pred_atom_mask
} }
true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) true_atom_position = torch.cat(
(chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)), true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_a_num_res, 37)),
torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)),
torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)),
torch.ones((1, self.chain_b_num_res, 37))), dim=1) torch.ones((1, self.chain_b_num_res, 37))), dim=1)
batch['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1)
batch['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1)
# tensor_to_cuda = lambda t: t.to('cuda') batch['all_atom_positions'] = true_atom_position
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth) batch['all_atom_mask'] = true_atom_mask
# Below create a fake_input_features
fake_input_features = {
'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
'aatype': torch.randint(21, size=(1, 325)),
'seq_length': torch.tensor([57])
}
fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape(
1, 325)
fake_input_features["residue_index"] = pad_features(
torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1)
fake_input_features['all_atom_positions'] = pad_features(
true_atom_position, nres_pad, pad_dim=1)
fake_input_features['all_atom_mask'] = pad_features(
true_atom_mask, nres_pad=nres_pad, pad_dim=1)
# NOTE
# batch: simulates ground_truth features
# fake_input_features: simulates the data that are going be used as input for model.forward(fake_input_features)
# out: simulates the output of model.forward(fake_input_features)
aligns, per_asym_residue_index = compute_permutation_alignment(out, aligns, per_asym_residue_index = compute_permutation_alignment(out,
batch, fake_input_features,
batch) batch)
print(f"##### aligns is {aligns}")
labels = split_ground_truth_labels(batch) labels = split_ground_truth_labels(batch)
labels = merge_labels(per_asym_residue_index, labels, aligns, labels = merge_labels(per_asym_residue_index, labels, aligns,
original_nres=batch['aatype'].shape[-1]) original_nres=batch['aatype'].shape[-1])
self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index'])) self.assertTrue(torch.equal(
labels['residue_index'], batch['residue_index']))
expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
dim=1) dim=1)
expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1)
self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos)) self.assertTrue(torch.equal(
labels['all_atom_positions'], expected_permutated_gt_pos))
...@@ -2,13 +2,19 @@ import argparse ...@@ -2,13 +2,19 @@ import argparse
import logging import logging
import os import os
import sys import sys
import json
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch import torch
import wandb
from deepspeed.utils import zero_to_fp32
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
...@@ -23,7 +29,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage ...@@ -23,7 +29,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import ( from openfold.utils.validation_metrics import (
...@@ -35,11 +40,6 @@ from openfold.utils.import_weights import ( ...@@ -35,11 +40,6 @@ from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
import_openfold_weights_ import_openfold_weights_
) )
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
...@@ -58,6 +58,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -58,6 +58,7 @@ class OpenFoldWrapper(pl.LightningModule):
self.cached_weights = None self.cached_weights = None
self.last_lr_step = -1 self.last_lr_step = -1
self.save_hyperparameters()
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
...@@ -68,14 +69,15 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -68,14 +69,15 @@ class OpenFoldWrapper(pl.LightningModule):
self.log( self.log(
f"{phase}/{loss_name}", f"{phase}/{loss_name}",
indiv_loss, indiv_loss,
on_step=train, on_epoch=(not train), logger=True, prog_bar=(loss_name == 'loss'),
on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
) )
if(train): if(train):
self.log( self.log(
f"{phase}/{loss_name}_epoch", f"{phase}/{loss_name}_epoch",
indiv_loss, indiv_loss,
on_step=False, on_epoch=True, logger=True, on_step=False, on_epoch=True, logger=True, sync_dist=False,
) )
with torch.no_grad(): with torch.no_grad():
...@@ -89,7 +91,8 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -89,7 +91,8 @@ class OpenFoldWrapper(pl.LightningModule):
self.log( self.log(
f"{phase}/{k}", f"{phase}/{k}",
torch.mean(v), torch.mean(v),
on_step=False, on_epoch=True, logger=True prog_bar = (k == 'loss'),
on_step=False, on_epoch=True, logger=True, sync_dist=False,
) )
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
...@@ -152,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -152,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule):
self._log(loss_breakdown, batch, outputs, train=False) self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _): def on_validation_epoch_end(self):
# Restore the model weights to normal # Restore the model weights to normal
self.model.load_state_dict(self.cached_weights) self.model.load_state_dict(self.cached_weights)
self.cached_weights = None self.cached_weights = None
...@@ -215,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -215,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
learning_rate: float = 1e-3, learning_rate: float = 1e-3,
eps: float = 1e-5, eps: float = 1e-5,
) -> torch.optim.Adam: ) -> torch.optim.Adam:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured # Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
self.model.parameters(), self.model.parameters(),
...@@ -269,35 +267,69 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -269,35 +267,69 @@ class OpenFoldWrapper(pl.LightningModule):
self.model, jax_path, version=model_version self.model, jax_path, version=model_version
) )
def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
latest_path = os.path.join(checkpoint_dir, 'latest')
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
_DS_CHECKPOINT_VERSION = 2 # based on manual parsing of checkpoint files
state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
return torch.load(state_file)
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed, workers=True)
is_low_precision = args.precision in [
"bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
config = model_config( config = model_config(
args.config_preset, args.config_preset,
train=True, train=True,
low_prec=(str(args.precision) == "16") low_prec=is_low_precision,
) )
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt): if args.resume_from_ckpt:
if(os.path.isdir(args.resume_from_ckpt)): if args.resume_model_weights_only:
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) # Load the checkpoint
else: if os.path.isdir(args.resume_from_ckpt):
sd = torch.load(args.resume_from_ckpt) sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
# Process the state dict
if 'module' in sd:
sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=sd)
elif 'state_dict' in sd:
import_openfold_weights_(
model=model_module, state_dict=sd['state_dict'])
else:
# Loading from pre-trained model
sd = {'model.'+k: v for k, v in sd.items()}
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
else: # Loads a checkpoint to start from a specific time step
if os.path.isdir(args.resume_from_ckpt):
sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step']) last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step) model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...") logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)): if args.resume_from_jax_params:
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
...@@ -355,7 +387,20 @@ def main(args): ...@@ -355,7 +387,20 @@ def main(args):
callbacks.append(lr_monitor) callbacks.append(lr_monitor)
loggers = [] loggers = []
is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
if(args.wandb): if(args.wandb):
if args.mpi_plugin and is_rank_zero:
wandb_init_dict = dict(
name=args.experiment_name,
project=args.wandb_project,
id=args.wandb_id,
dir=args.output_dir,
resume="allow",
anonymous=None,
entity=args.wandb_entity
)
wandb.run = wandb.init(**wandb_init_dict)
wdb_logger = WandbLogger( wdb_logger = WandbLogger(
name=args.experiment_name, name=args.experiment_name,
save_dir=args.output_dir, save_dir=args.output_dir,
...@@ -365,32 +410,39 @@ def main(args): ...@@ -365,32 +410,39 @@ def main(args):
) )
loggers.append(wdb_logger) loggers.append(wdb_logger)
cluster_environment = MPIEnvironment() if args.mpi_plugin else None
if(args.deepspeed_config_path is not None): if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin( strategy = DeepSpeedStrategy(
config=args.deepspeed_config_path, config=args.deepspeed_config_path,
cluster_environment=cluster_environment,
) )
if(args.wandb): if(args.wandb and is_rank_zero):
wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save(args.deepspeed_config_path)
wdb_logger.experiment.save("openfold/config.py") wdb_logger.experiment.save("openfold/config.py")
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1: elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
strategy = DDPPlugin(find_unused_parameters=False) strategy = DDPStrategy(find_unused_parameters=False,
cluster_environment=cluster_environment)
else: else:
strategy = None strategy = None
if(args.wandb): if(args.wandb and is_rank_zero):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt" freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}") os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}") wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args( trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
args, 'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
default_root_dir=args.output_dir, trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
strategy=strategy, trainer_args.update({
callbacks=callbacks, 'default_root_dir': args.output_dir,
logger=loggers, 'strategy': strategy,
) 'callbacks': callbacks,
'logger': loggers,
})
trainer = pl.Trainer(**trainer_args)
if(args.resume_model_weights_only):
if (args.resume_model_weights_only):
ckpt_path = None ckpt_path = None
else: else:
ckpt_path = args.resume_from_ckpt ckpt_path = args.resume_from_ckpt
...@@ -595,23 +647,42 @@ if __name__ == "__main__": ...@@ -595,23 +647,42 @@ if __name__ == "__main__":
"--distillation_alignment_index_path", type=str, default=None, "--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions." help="Distillation alignment index. See the README for instructions."
) )
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
# Disable the initial validation pass )
parser.set_defaults( parser.add_argument(
num_sanity_val_steps=0, "--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.'
) )
parser.add_argument("--mpi_plugin", action="store_true", default=False,
# Remove some buggy/redundant arguments introduced by the Trainer help="Whether to use MPI for parallele processing")
remove_arguments(
parser, trainer_group = parser.add_argument_group(
[ 'Arguments to pass to PyTorch Lightning Trainer')
"--accelerator", trainer_group.add_argument(
"--resume_from_checkpoint", "--num_nodes", type=int, default=1,
"--reload_dataloaders_every_epoch", )
"--reload_dataloaders_every_n_epochs", trainer_group.add_argument(
] "--precision", type=str, default='bf16',
) help='Sets precision, lower precision improves runtime performance.',
)
trainer_group.add_argument(
"--max_epochs", type=int, default=1,
)
trainer_group.add_argument(
"--log_every_n_steps", type=int, default=25,
)
trainer_group.add_argument(
"--flush_logs_every_n_steps", type=int, default=5,
)
trainer_group.add_argument(
"--num_sanity_val_steps", type=int, default=0,
)
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
args = parser.parse_args() args = parser.parse_args()
...@@ -626,7 +697,5 @@ if __name__ == "__main__": ...@@ -626,7 +697,5 @@ if __name__ == "__main__":
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment