Commit a1c29028 authored by zhangqha

update uni-fold

#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips all required data for AlphaFold.
#
# Usage: bash download_all_data.sh /path/to/download/directory [full_dbs|reduced_dbs]
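#
# Example invocations (illustrative paths; not part of the original script):
#   bash download_all_data.sh /data/alphafold               # full databases (default)
#   bash download_all_data.sh /data/alphafold reduced_dbs   # smaller, reduced databases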
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs.
if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
then
echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
exit 1
fi
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
echo "Downloading AlphaFold parameters..."
bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then
echo "Downloading Small BFD..."
bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
else
echo "Downloading BFD..."
bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
fi
echo "Downloading MGnify..."
bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB70..."
bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniclust30..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the AlphaFold parameters.
#
# Usage: bash download_alphafold_params.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}" --preserve-permissions
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the BFD database for AlphaFold.
#
# Usage: bash download_bfd.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/bfd"
# Mirror of:
# https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz.
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the MGnify database for AlphaFold.
#
# Usage: bash download_mgnify.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${BASENAME}"
popd
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB70 database for AlphaFold.
#
# Usage: bash download_pdb70.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb70"
SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c --check-certificate=false "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and flattens the PDB database for AlphaFold.
#
# Usage: bash download_pdb_mmcif.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
if ! command -v rsync &> /dev/null ; then
echo "Error: rsync could not be found. Please install rsync."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_mmcif"
RAW_DIR="${ROOT_DIR}/raw"
MMCIF_DIR="${ROOT_DIR}/mmcif_files"
echo "Running rsync to fetch all mmCIF files (note that the rsync progress estimate might be inaccurate)..."
mkdir --parents "${RAW_DIR}"
rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 \
rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ \
"${RAW_DIR}"
echo "Unzipping all mmCIF files..."
find "${RAW_DIR}/" -type f -iname "*.gz" -exec gunzip {} +
echo "Flattening all mmCIF files..."
mkdir --parents "${MMCIF_DIR}"
find "${RAW_DIR}" -type d -empty -delete # Delete empty directories.
for subdir in "${RAW_DIR}"/*; do
mv "${subdir}/"*.cif "${MMCIF_DIR}"
done
# Delete empty download directory structure.
find "${RAW_DIR}" -type d -empty -delete
aria2c "ftp://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat" --dir="${ROOT_DIR}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the Small BFD database for AlphaFold.
#
# Usage: bash download_small_bfd.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/small_bfd"
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${BASENAME}"
popd
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the Uniclust30 database for AlphaFold.
#
# Usage: bash download_uniclust30.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniclust30"
# Mirror of:
# http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/uniclust30_2018_08_hhsuite.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
# Concatenate TrEMBL and SwissProt and rename the combined file to uniprot.fasta.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
#rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the UniRef90 database for AlphaFold.
#
# Usage: bash download_uniref90.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniref90"
SOURCE_URL="ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${BASENAME}"
popd
import os
import json
import sys
from tqdm import tqdm
from multiprocessing import Pool
import glob
import pickle
import numpy as np
from functools import partial
import gzip
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap))
)
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = np.array(
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=np.int8
)
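# Illustrative note (not part of the original script): ID_TO_HHBLITS_AA follows the
# HHblits residue ordering, while restypes follows the AlphaFold ordering, so e.g.
# HHblits index 1 ("C") maps to restypes_with_x_and_gap.index("C") == 4, i.e.
# MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[1] == 4.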
data_dir = sys.argv[1]
data_type = sys.argv[2]
output_dir = sys.argv[3]
prefix = sys.argv[4]
feature_dir = "{}/{}_features/".format(data_dir, data_type)
label_dir = "{}/{}_labels/".format(data_dir, data_type)
feature_files = glob.glob(feature_dir + "*")
cluster_size = json.load(
open(os.path.join(data_dir, "{}_cluster_size.json".format(data_type)))
)
if data_type == "pdb":
multi_label = json.load(open(os.path.join(data_dir, "pdb_multi_label.json")))
else:
multi_label = None
new_sample_weight = {}
new_multi_label = {}
def __load_from_file__(path):
if path.endswith(".pkl"):
return pickle.load(open(path, "rb"))
else:
return pickle.load(gzip.open(path, "rb"))
def get_sample_weight(len, cs):
p1 = max(min(len, 512), 256) / 512
p2 = len**2 / 1024
return min(p1, p2) / cs
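# Worked example (illustrative, not part of the original script): for a chain of
# length 300 in a cluster of size 5, p1 = max(min(300, 512), 256) / 512 ~ 0.586 and
# p2 = 300**2 / 1024 ~ 87.9, so the weight is min(p1, p2) / 5 ~ 0.117; short chains
# are down-weighted through p2 and large clusters through the cluster-size divisor.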
def check_one_file(file):
t = os.path.split(file)[-1].split(".")[0]
raw_feature = __load_from_file__(file)
seq_len = raw_feature["aatype"].shape[0]
aatype = np.argmax(raw_feature["aatype"], axis=-1)
msa_aatype = np.array(
[MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[ii] for ii in raw_feature["msa"][0]]
)
if not (aatype == msa_aatype).all():
return t, None, None
_, counts = np.unique(aatype, return_counts=True)
freqs = counts.astype(np.float32) / seq_len
max_freq = np.max(freqs)
labels = []
def load_and_check_label(label_t, seq_len):
label_filename = os.path.join(label_dir, "{}.label.pkl.gz".format(label_t))
if os.path.isfile(label_filename):
raw_label = __load_from_file__(label_filename)
label_aatype = raw_label["aatype_index"]
label_seq_len = raw_label["all_atom_positions"].shape[0]
resolution = raw_label["resolution"].reshape(1)[0]
if (
label_seq_len == seq_len
and (aatype == label_aatype).all()
and resolution < 9
):
return True
return False
if multi_label is None or t not in multi_label:
if load_and_check_label(t, seq_len):
labels.append(t)
else:
for label_t in multi_label[t]:
if load_and_check_label(label_t, seq_len):
labels.append(label_t)
if len(labels) > 0 and t in cluster_size and max_freq < 0.8:
sample_weight = get_sample_weight(seq_len, cluster_size[t])
return t, sample_weight, labels
else:
return t, None, None
file_cnt = len(feature_files)
filter_cnt = 0
error_features = []
error_labels = []
with Pool(96) as pool:
for ret in tqdm(pool.imap(check_one_file, feature_files), total=file_cnt):
t, sw, ll = ret
if sw is not None:
new_sample_weight[t] = sw
new_multi_label[t] = ll
if multi_label is not None and len(ll) < len(multi_label[t]):
for x in multi_label[t]:
if x not in ll:
error_labels.append(x)
else:
if len(ll) <= 0:
error_labels.append(t)
else:
error_features.append(t)
filter_cnt += 1
print(len(error_features), len(error_labels))
def write_list_to_file(a, file):
with open(file, "w") as output:
for x in a:
output.write(str(x) + "\n")
write_list_to_file(
error_features, "{}/{}_error_features.txt".format(output_dir, data_type)
)
write_list_to_file(error_labels, "{}/{}_error_labels.txt".format(output_dir, data_type))
if data_type == "pdb":
json.dump(
new_sample_weight,
open("{}/{}_train_sample_weight.json".format(output_dir, prefix), "w"),
indent=4,
)
json.dump(
new_multi_label,
open("{}/{}_train_multi_label.json".format(output_dir, prefix), "w"),
indent=4,
)
else:
json.dump(
new_sample_weight,
open("{}/{}_sd_train_sample_weight.json".format(output_dir, prefix), "w"),
indent=4,
)
import os
import json
import sys
from tqdm import tqdm
from multiprocessing import Pool
import glob
import pickle
import numpy as np
from functools import partial
import gzip
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap))
)
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = np.array(
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=np.int8
)
data_dir = sys.argv[1]
data_type = sys.argv[2]
output_dir = sys.argv[3]
prefix = sys.argv[4]
feature_dir = "{}/{}_features/".format(data_dir, data_type)
label_dir = "{}/{}_labels/".format(data_dir, data_type)
uniprot_dir = "{}/{}_uniprots/".format(data_dir, data_type)
feature_files = glob.glob(feature_dir + "*")
cluster_size = json.load(
open(os.path.join(data_dir, "{}_cluster_size.json".format(data_type)))
)
if data_type == "pdb":
multi_label = json.load(open(os.path.join(data_dir, "pdb_multi_label.json")))
pdb_assembly = json.load(open(os.path.join(data_dir, "pdb_assembly.json")))
else:
multi_label = None
pdb_assembly = None
new_sample_weight = {}
new_multi_label = {}
def __load_from_file__(path):
if path.endswith(".pkl"):
return pickle.load(open(path, "rb"))
else:
return pickle.load(gzip.open(path, "rb"))
def get_sample_weight(len, cs):
return 1.0 / cs
def check_one_file(file):
t = os.path.split(file)[-1].split(".")[0]
raw_feature = __load_from_file__(file)
seq_len = raw_feature["aatype"].shape[0]
aatype = np.argmax(raw_feature["aatype"], axis=-1).astype(np.int64)
msa_aatype = np.array(
[MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[ii] for ii in raw_feature["msa"][0]]
).astype(np.int64)
if not (aatype == msa_aatype).all():
return t, None, None
_, counts = np.unique(aatype, return_counts=True)
freqs = counts.astype(np.float32) / seq_len
max_freq = np.max(freqs)
labels = []
def load_and_check_label(label_t, seq_len):
label_filename = os.path.join(label_dir, "{}.label.pkl.gz".format(label_t))
if os.path.isfile(label_filename):
raw_label = __load_from_file__(label_filename)
label_aatype = raw_label["aatype_index"].astype(np.int64)
label_seq_len = raw_label["all_atom_positions"].shape[0]
resolution = raw_label["resolution"].reshape(1)[0]
if (
label_seq_len == seq_len
and (aatype == label_aatype).all()
and resolution < 9
):
return True
return False
def load_and_check_uniprot(label_t, seq_len):
label_filename = os.path.join(uniprot_dir, "{}.uniprot.pkl.gz".format(label_t))
if os.path.isfile(label_filename):
raw_label = __load_from_file__(label_filename)
uniprot_seq_len = raw_label["msa"].shape[-1]
msa_aatype = np.array(
[MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[ii] for ii in raw_label["msa"][0]]
).astype(np.int64)
if seq_len == uniprot_seq_len and (msa_aatype == aatype).all():
return True
return False
if multi_label is None or t not in multi_label:
if load_and_check_label(t, seq_len):
labels.append(t)
else:
for label_t in multi_label[t]:
if load_and_check_label(label_t, seq_len):
labels.append(label_t)
if len(labels) > 0 and t in cluster_size and (max_freq < 0.8 or data_type == "pdb"):
if load_and_check_uniprot(t, seq_len):
sample_weight = get_sample_weight(seq_len, cluster_size[t])
return t, sample_weight, labels
else:
return t, "uniprot", False
else:
return t, None, None
file_cnt = len(feature_files)
filter_cnt = 0
error_features = []
error_labels = []
error_uniprots = []
with Pool(64) as pool:
for ret in tqdm(pool.imap(check_one_file, feature_files), total=file_cnt):
t, sw, ll = ret
if sw == "uniprot":
error_uniprots.append(t)
elif sw is not None:
new_sample_weight[t] = sw
new_multi_label[t] = ll
if multi_label is not None and len(ll) < len(multi_label[t]):
for x in multi_label[t]:
if x not in ll:
error_labels.append(x)
else:
if len(ll) <= 0:
error_labels.append(t)
else:
error_features.append(t)
filter_cnt += 1
def _inverse_map(mapping):
inverse_mapping = {}
for ent, refs in mapping.items():
for ref in refs:
if ref in inverse_mapping: # duplicated ent for this ref.
ent_2 = inverse_mapping[ref]
assert (
ent == ent_2
), f"multiple entities ({ent_2}, {ent}) exist for reference {ref}."
inverse_mapping[ref] = ent
return inverse_mapping
def get_assembly(canon_chain_map):
pdb_chains = {}
for chain in canon_chain_map:
pdb = chain.split("_")[0]
if pdb not in pdb_chains:
pdb_chains[pdb] = []
pdb_chains[pdb].append(chain)
return pdb_chains
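# Illustrative example (not part of the original script): if new_multi_label maps a
# feature id to label chains, e.g. {"featA": ["1abc_A", "1abc_B"]} (the id "featA" is
# hypothetical), then _inverse_map(...) gives {"1abc_A": "featA", "1abc_B": "featA"},
# and get_assembly(...) groups the chain ids by PDB id: {"1abc": ["1abc_A", "1abc_B"]}.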
pdb_map = get_assembly(_inverse_map(new_multi_label))
def check_pdb(pdb):
chains = pdb_map[pdb]
len_chains = len(chains)
if pdb not in pdb_assembly:
if len(chains) > 2:
print("unknown assembly", pdb)
return pdb, False
else:
chains = set(chains)
pdb_chains = pdb_assembly[pdb]["chains"]
new_chains = []
complete = True
for chain in pdb_chains:
chain_name = pdb + "_" + chain
if chain_name not in chains:
complete = False
break
new_chains.append(chain_name)
if not complete:
print("not complete", pdb, chains, pdb_chains)
return pdb, False
chains = new_chains
aatypes = []
for chain in chains:
label_filename = os.path.join(label_dir, "{}.label.pkl.gz".format(chain))
raw_label = __load_from_file__(label_filename)
label_aatype = raw_label["aatype_index"].astype(np.int64).reshape(-1)
aatypes.append(label_aatype)
aatype = np.concatenate(aatypes).reshape(-1)
_, counts = np.unique(aatype, return_counts=True)
freqs = counts.astype(np.float32) / aatype.shape[0]
max_freq = np.max(freqs)
if max_freq < 0.8:
return pdb, True
else:
return pdb, False
error_pdbs = []
if data_type == "pdb":
pdbs = pdb_map.keys()
cnt_pdbs = len(pdbs)
available_pdbs = {}
with Pool(64) as pool:
for ret in tqdm(pool.imap(check_pdb, pdbs), total=cnt_pdbs):
if ret[1]:
available_pdbs[ret[0]] = 1
else:
error_pdbs.append(ret[0])
new_multi_label2 = {}
for key in new_multi_label:
labels = []
for v in new_multi_label[key]:
if v.split("_")[0] in available_pdbs:
labels.append(v)
if labels:
new_multi_label2[key] = labels
new_multi_label = new_multi_label2
print(len(error_features), len(error_labels), len(error_uniprots), len(error_pdbs))
def write_list_to_file(a, file):
with open(file, "w") as output:
for x in a:
output.write(str(x) + "\n")
write_list_to_file(error_pdbs, "{}/{}_error_pdbs.txt".format(output_dir, data_type))
write_list_to_file(
error_features, "{}/{}_error_features.txt".format(output_dir, data_type)
)
write_list_to_file(error_labels, "{}/{}_error_labels.txt".format(output_dir, data_type))
write_list_to_file(
error_uniprots, "{}/{}_error_uniprots.txt".format(output_dir, data_type)
)
if data_type == "pdb":
json.dump(
new_sample_weight,
open("{}/{}_train_sample_weight.json".format(output_dir, prefix), "w"),
indent=4,
)
json.dump(
new_multi_label,
open("{}/{}_train_multi_label.json".format(output_dir, prefix), "w"),
indent=4,
)
else:
json.dump(
new_sample_weight,
open("{}/{}_sd_train_sample_weight.json".format(output_dir, prefix), "w"),
indent=4,
)
import os, sys
import shlex
import glob
import json
from tqdm import tqdm
from multiprocessing import Pool
from unifold.msa.mmcif import parse
rot_keys = """_pdbx_struct_oper_list.matrix[1][1]
_pdbx_struct_oper_list.matrix[1][2]
_pdbx_struct_oper_list.matrix[1][3]
_pdbx_struct_oper_list.matrix[2][1]
_pdbx_struct_oper_list.matrix[2][2]
_pdbx_struct_oper_list.matrix[2][3]
_pdbx_struct_oper_list.matrix[3][1]
_pdbx_struct_oper_list.matrix[3][2]
_pdbx_struct_oper_list.matrix[3][3]""".split(
"\n"
)
tran_keys = """_pdbx_struct_oper_list.vector[1]
_pdbx_struct_oper_list.vector[2]
_pdbx_struct_oper_list.vector[3]""".split(
"\n"
)
def process_block_to_dict(content):
ret = {}
lines = content.split("\n")
if lines[0] == "loop_":
keys = []
values = []
last_val = []
for line in lines[1:]:
line = line.strip()
if line.startswith("_"):
keys.append(line)
else:
num_key = len(keys)
cur_vals = shlex.split(line)
last_val.extend(cur_vals)
assert len(last_val) <= num_key, (
num_key,
len(last_val),
last_val,
cur_vals,
)
if len(last_val) == num_key:
values.append(last_val)
last_val = []
if last_val:
assert len(last_val) == num_key
values.append(last_val)
for i, k in enumerate(keys):
ret[k] = [vals[i] for vals in values]
else:
last_val = []
for line in lines:
t = shlex.split(line)
last_val.extend(t)
if len(last_val) == 2:
ret[last_val[0]] = [last_val[1]]
last_val = []
if last_val:
assert len(last_val) == 2
ret[last_val[0]] = [last_val[1]]
return ret
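# Illustrative example (not part of the original script): for a loop_ block such as
#   loop_
#   _pdbx_struct_oper_list.id
#   _pdbx_struct_oper_list.type
#   1 'identity operation'
#   2 'point symmetry operation'
# process_block_to_dict returns
#   {"_pdbx_struct_oper_list.id": ["1", "2"],
#    "_pdbx_struct_oper_list.type": ["identity operation", "point symmetry operation"]};
# non-loop blocks of "_key value" lines are returned as {key: [value]}.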
def get_transform(data, idx):
idx = int(idx) - 1
rot = []
for key in rot_keys:
rot.append(float(data[key][idx]))
trans = []
for key in tran_keys:
trans.append(float(data[key][idx]))
rot, trans = tuple(rot), tuple(trans)
if rot == (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) and trans == (
0.0,
0.0,
0.0,
):
return "I"
else:
return rot, trans
def parse_assembly(mmcif_path):
name = os.path.split(mmcif_path)[-1].split(".")[0]
with open(mmcif_path) as f:
mmcif_string = f.read()
mmcif_lines = mmcif_string.split("\n")
parse_result = parse(file_id="", mmcif_string=mmcif_string)
if "No protein chains found in this file." in parse_result.errors.values():
return name, [], [], [], "no protein"
mmcif_obj = parse_result.mmcif_object
if mmcif_obj is None:
print(name, parse_result.errors)
return name, [], [], [], "parse error"
mmcif_to_author_chain_id = mmcif_obj.mmcif_to_author_chain_id
valid_chains = mmcif_obj.valid_chains.keys()
    valid_chains = set(valid_chains)  # label asym ids, not author chain ids
new_section = False
is_loop = False
cur_lines = []
assembly = None
assembly_gen = None
oper = None
error_type = None
try:
for line in mmcif_lines:
line = line.strip().replace(";", "")
if not line:
continue
if line == "#":
cur_str = "\n".join(cur_lines)
if "revision" in cur_str:
continue
if "_pdbx_struct_assembly.id" in cur_str:
assembly = process_block_to_dict(cur_str)
if "_pdbx_struct_assembly_gen.assembly_id" in cur_str:
assembly_gen = process_block_to_dict(cur_str)
if "_pdbx_struct_oper_list.id" in cur_str:
oper = process_block_to_dict(cur_str)
cur_lines = []
else:
cur_lines.append(line)
except Exception as e:
print(name, e)
return name, [], [], [], "parse"
    if assembly is None or assembly_gen is None or oper is None:
return name, [], [], [], "miss"
try:
counts = assembly["_pdbx_struct_assembly.oligomeric_count"]
asym_id = assembly_gen["_pdbx_struct_assembly_gen.assembly_id"]
op_idx = assembly_gen["_pdbx_struct_assembly_gen.oper_expression"]
assembly_chains = assembly_gen["_pdbx_struct_assembly_gen.asym_id_list"]
chains = []
chains_ops = []
for i, j in enumerate(asym_id):
if j == "1":
sss = (
op_idx[i]
.replace("(", "")
.replace(")", "")
.replace("'", "")
.replace('"', "")
)
if "-" in sss:
s, t = sss.split("-")
indices = range(int(s), int(t) + 1)
else:
indices = sss.split(",")
for idx in indices:
chains.append(assembly_chains[i].split(","))
chains_ops.append(get_transform(oper, idx))
len_ops = len(chains)
count = int(counts[0])
all_chains = []
all_chains_ops = []
all_chains_label = []
for i, cur_chains in enumerate(chains):
for chain in cur_chains:
if chain not in valid_chains:
continue
auth_chain = mmcif_to_author_chain_id[chain]
all_chains_label.append(chain)
all_chains.append(auth_chain)
all_chains_ops.append(chains_ops[i])
return name, all_chains, all_chains_label, all_chains_ops, "success"
except Exception as e:
print(name, e)
return name, [], [], [], "index"
input_dir = sys.argv[1]
output_file = sys.argv[2]
input_files = glob.glob(input_dir + "*.cif")
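# Note (not part of the original script): input_dir is used as a raw glob prefix, so it
# should end with a path separator, e.g. (hypothetical invocation)
#   python this_script.py /data/pdb_mmcif/mmcif_files/ pdb_assembly_meta.json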
file_cnt = len(input_files)
meta_dict = {}
failed = []
with Pool(64) as pool:
for ret in tqdm(pool.imap(parse_assembly, input_files), total=file_cnt):
name, all_chains, all_chains_label, all_chains_ops, error_type = ret
if all_chains:
meta_dict[name] = {}
meta_dict[name]["chains"] = all_chains
meta_dict[name]["chains_label"] = all_chains_label
meta_dict[name]["opers"] = all_chains_ops
else:
failed.append(name + " " + error_type)
json.dump(meta_dict, open(output_file, "w"), indent=2)
def write_list_to_file(a, file):
with open(file, "w") as output:
for x in a:
output.write(str(x) + "\n")
write_list_to_file(failed, "failed_mmcif.txt")
import os, sys
import shlex
import glob
import json
from tqdm import tqdm
from multiprocessing import Pool
from unifold.msa.mmcif import parse
import gzip
def parse_assembly(mmcif_path):
name = os.path.split(mmcif_path)[-1].split(".")[0]
if mmcif_path.endswith(".gz"):
with gzip.open(mmcif_path, "rb") as f:
mmcif_string = f.read().decode()
mmcif_lines = mmcif_string.split("\n")
else:
with open(mmcif_path, "rb") as f:
mmcif_string = f.read()
mmcif_lines = mmcif_string.split("\n")
parse_result = parse(file_id="", mmcif_string=mmcif_string)
if "No protein chains found in this file." in parse_result.errors.values():
return name, None
mmcif_obj = parse_result.mmcif_object
if mmcif_obj is None:
print(name, parse_result.errors)
return name, None
mmcif_to_author_chain_id = mmcif_obj.mmcif_to_author_chain_id
valid_chains = mmcif_obj.valid_chains.keys()
    valid_chains = list(set(valid_chains))  # label asym ids, not author chain ids
return name, {"to_auth_id": mmcif_to_author_chain_id, "valid_chains": valid_chains}
input_dir = sys.argv[1]
output_file = sys.argv[2]
input_files = glob.glob(input_dir + "*.cif") + glob.glob(input_dir + "*.cif.gz")
file_cnt = len(input_files)
meta_dict = {}
failed = []
with Pool(64) as pool:
for ret in tqdm(pool.imap(parse_assembly, input_files), total=file_cnt):
name, res = ret
if res:
meta_dict[name] = res
json.dump(meta_dict, open(output_file, "w"), indent=2)
import os, sys
import glob
import json
from tqdm import tqdm
from multiprocessing import Pool
import requests
import time
input_dir = sys.argv[1]
output_dir = sys.argv[2]
input_files = glob.glob(input_dir + "*.cif") + glob.glob(input_dir + "*.cif.gz")
os.system("mkdir -p " + output_dir)
pdb_chain_mapper = json.load(open(sys.argv[3]))
rot_keys = [
"matrix11",
"matrix12",
"matrix13",
"matrix21",
"matrix22",
"matrix23",
"matrix31",
"matrix32",
"matrix33",
]
trans_keys = ["vector1", "vector2", "vector3"]
def get_oper(cont):
cont = cont["pdbx_struct_oper_list"]
ret = {}
for c in cont:
id = c["id"]
rot = []
trans = []
for k in rot_keys:
rot.append(c[k])
for k in trans_keys:
trans.append(c[k])
rot, trans = tuple(rot), tuple(trans)
if rot == (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) and trans == (
0.0,
0.0,
0.0,
):
ret[id] = "I"
else:
ret[id] = (rot, trans)
ret["I"] = "I"
return ret
# refer to https://data.rcsb.org/redoc/index.html
def get_pdb_meta_info(mmcif_path):
name = os.path.split(mmcif_path)[-1].split(".")[0]
url = f"https://data.rcsb.org/rest/v1/core/assembly/{name}/1"
max_try_time = 10
for i in range(max_try_time):
try:
out_path = os.path.join(output_dir, name + ".json")
load = False
if os.path.isfile(out_path):
cont = json.load(open(out_path, "r"))
load = True
else:
r = requests.get(url)
cont = json.loads(r.text)
json.dump(cont, open(out_path, "w"))
load = r.ok
if load:
if "rcsb_struct_symmetry" not in cont:
break
cur_mapper = pdb_chain_mapper[name]["to_auth_id"]
for tt in cont["rcsb_struct_symmetry"]:
if tt["kind"] == "Global Symmetry":
symbol = tt["symbol"]
stoi = tt["stoichiometry"]
all_opers = get_oper(cont)
chains = []
opers = []
for c in tt["clusters"]:
for m in c["members"]:
chain_id = cur_mapper[m["asym_id"]]
if "pdbx_struct_oper_list_ids" in m:
for op_idx in m["pdbx_struct_oper_list_ids"]:
chains.append(chain_id)
opers.append(all_opers[op_idx])
else:
chains.append(chain_id)
opers.append("I")
return name, {
"symbol": symbol,
"stoi": stoi,
"chains": chains,
"opers": opers,
}
break
elif cont["status"] == "404":
break
except Exception as e:
print(name, e)
time.sleep(2)
return name, None
file_cnt = len(input_files)
meta_dict = {}
# failed = []
with Pool(64) as pool:
for ret in tqdm(pool.imap(get_pdb_meta_info, input_files), total=file_cnt):
name, res = ret
if res:
meta_dict[name] = res
json.dump(meta_dict, open("pdb_assembly.json", "w"), indent=2)
# Copyright 2022 DP Technology
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
def reshape_weight(x):
len_shape = len(x.shape)
if len_shape == 2:
return x.transpose(-1, -2)
elif len_shape == 1:
return x.reshape(-1, 1)
# ParamType, used together with Param below: a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: reshape_weight(w)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
)
LinearMHAOutputWeight = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
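# Illustrative note (not part of the original module): these transformations adapt
# JAX/Haiku weight layouts to torch.nn.Linear, which stores its weight as
# (out_features, in_features). LinearWeight transposes a (d_in, d_out) array to
# (d_out, d_in); LinearWeightMHA additionally flattens the per-head dimensions,
# e.g. (d_in, n_heads, d_head) -> (n_heads * d_head, d_in).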
@dataclass
class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
def _process_translations_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
v, top_layer=False
).items()
}
flat.update(sub_flat)
else:
k = "/" + k if not top_layer else k
flat[k] = v
return flat
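# Illustrative example (not part of the original module): a nested translations dict
# such as {"evoformer": {"preprocess_1d": {"weights": ..., "bias": ...}}} is flattened
# into npz-style keys, e.g. "alphafold/alphafold_iteration/evoformer/preprocess_1d//weights"
# (note the double slash before the parameter name).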
def stacked(param_dict_list, out=None):
"""
Args:
param_dict_list:
A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
"parallel" Params). There must be at least one dict
in the list.
"""
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
)
out[k] = stacked_param
return out
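# Illustrative usage (not part of the original module): given per-block dicts
#   blocks = [{"msa_transition": {"weights": Param(w0)}},
#             {"msa_transition": {"weights": Param(w1)}}]
# stacked(blocks) returns {"msa_transition": {"weights": Param(param=[w0, w1], stacked=True)}},
# which assign() later unbinds along dim 0 so one stacked npz entry fills a whole layer stack.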
def assign(translation_dict, orig_weights):
for k, param in translation_dict.items():
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
ref = [ref]
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
p.copy_(w)
except:
print(k)
print(ref[0].shape)
print(weights[0].shape)
raise
def import_jax_weights_(model, npz_path, version="model_1"):
is_multimer = False
if version in ["multimer_af2"]:
is_multimer = True
data = np.load(npz_path, allow_pickle=True)
if 'arr_0' in data:
data = data['arr_0'].flat[0]
global _NPZ_KEY_PREFIX
_NPZ_KEY_PREFIX = "unifold/unifold_iteration/"
keys = list(data.keys())
for key in keys:
for subkey in data[key]:
data[key + '//' + subkey] = data[key][subkey]
del data[key]
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearLeftParams = lambda l,index: {
"weights": LinearWeight(l.weight[:index,:]),
"bias": LinearBias(l.bias[:index]),
}
LinearRightParams = lambda l,index: {
"weights": LinearWeight(l.weight[index:,:]),
"bias": LinearBias(l.bias[index:]),
}
LinearMHAParams = lambda l: {
"weights": LinearWeightMHA(l.weight),
"bias": LinearBiasMHA(l.bias),
}
LinearNoBiasParams = lambda l: {
"weights": LinearWeight(l.weight),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
}
AttentionParams = lambda att: {
"query_w": LinearWeightMHA(att.linear_q.weight),
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
AttentionGatedParams = lambda att: dict(
**AttentionParams(att),
**{
"gating_w": LinearWeightMHA(att.linear_g.weight),
"gating_b": LinearBiasMHA(att.linear_g.bias),
},
)
GlobalAttentionParams = lambda att: dict(
AttentionGatedParams(att),
key_w=LinearWeight(att.linear_k.weight),
value_w=LinearWeight(att.linear_v.weight),
)
TriAttParams = lambda tri_att: {
"query_norm": LayerNormParams(tri_att.layer_norm),
"feat_2d_weights": LinearWeight(tri_att.linear.weight),
"attention": AttentionGatedParams(tri_att.mha),
}
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"right_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"left_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"right_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
    # See commit b88f8da on the AlphaFold repo:
    # AlphaFold swaps the pseudocode's a and b between the incoming/outgoing
    # iterations of triangle multiplication, which is confusing and not
    # reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"right_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"left_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"right_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
"transition1": LinearParams(pt.linear_1),
"transition2": LinearParams(pt.linear_2),
}
MSAAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAColAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
**MSAAttParams(matt),
**{
"feat_2d_norm": LayerNormParams(matt.layer_norm_z),
"feat_2d_weights": LinearWeight(matt.linear_z.weight),
},
)
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
if is_multimer:
MultimerIPAParams = lambda ipa: {
"q_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_q.weight)},
"k_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_k.weight)},
"v_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_v.weight)},
"q_point_projection": {"point_projection": LinearMHAParams(ipa.linear_q_points)},
"k_point_projection": {"point_projection": LinearMHAParams(ipa.linear_k_points)},
"v_point_projection": {"point_projection": LinearMHAParams(ipa.linear_v_points)},
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
TemplatePairBlockParams = lambda b: {
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"pair_transition": PairTransitionParams(b.pair_transition),
}
MSATransitionParams = lambda m: {
"input_layer_norm": LayerNormParams(m.layer_norm),
"transition1": LinearParams(m.linear_1),
"transition2": LinearParams(m.linear_2),
}
OuterProductMeanParams = lambda o: {
"layer_norm_input": LayerNormParams(o.layer_norm),
"left_projection": LinearParams(o.linear_1),
"right_projection": LinearParams(o.linear_2),
"output_w": LinearWeightOPM(o.linear_out.weight),
"output_b": LinearBias(o.linear_out.bias),
}
def EvoformerBlockParams(b, is_extra_msa=False):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
col_att_name = "msa_column_attention"
msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"pair_transition": PairTransitionParams(b.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
"transition_2": LinearParams(sm.transition.layers[0].linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
if is_multimer:
MultimerFoldIterationParams = lambda sm: {
"invariant_point_attention": MultimerIPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
"transition_2": LinearParams(sm.transition.layers[0].linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"quat_rigid": {"rigid": LinearParams(sm.bb_update.linear)},
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
############################
# translations dict overflow
############################
tps_blocks_params = None
template_pair_ln = None
template_angle_emb = None
template_angle_proj = None
if model.template_pair_stack is not None:
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
template_pair_ln = LayerNormParams(model.template_pair_stack.layer_norm)
template_angle_emb = LinearParams(model.template_angle_embedder.linear_1)
template_angle_proj = LinearParams(model.template_angle_embedder.linear_2)
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"template_embedding": {
"single_template_embedding": {
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": template_pair_ln,
},
# "attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": template_angle_emb,
"template_projection": template_angle_proj,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": MultimerFoldIterationParams(model.structure_module) if is_multimer else FoldIterationParams(model.structure_module)
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
no_temp = version in ["model_3_af2", "model_4_af2", "model_5_af2"]
if no_temp:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if "template_" in k:
evo_dict.pop(k)
if "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.pae.linear)
}
if is_multimer:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.pae.linear)
}
# fix rel-pos embedding
del translations["evoformer"]["pair_activiations"]
translations["evoformer"]["~_relative_encoding"] = {}
translations["evoformer"]["~_relative_encoding"]["position_activations"] = LinearParams(
model.input_embedder.linear_relpos
)
for i in range(8):
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_{}".format(i)] = LinearParams(
model.template_pair_embedder.linear[i]
)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_8"] = LinearParams(
model.template_pair_embedder.z_linear
)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["query_embedding_norm"] = LayerNormParams(
model.template_pair_embedder.z_layer_norm
)
del translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_stack"]
translations["evoformer"]["template_embedding"]["output_linear"] = LinearParams(
model.template_proj.output_linear
)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_embedding_iteration"] = tps_blocks_params
else:
if not no_temp:
translations["evoformer"]["template_embedding"]["single_template_embedding"]["embedding2d"] = LinearParams(
model.template_pair_embedder.linear
)
translations["evoformer"]["template_embedding"]["attention"] = AttentionParams(model.template_pointwise_att.mha)
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
# Sanity check
keys = list(data.keys())
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
# assert len(missing) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
print("incorrect keys:", incorrect)
print("missing keys:", missing)
# Set weights
assign(flat, data)
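# Illustrative usage (not part of the original module; the model constructor below is a
# placeholder for however the Uni-Fold model is built elsewhere in the repository):
#   model = build_alphafold_model(config)  # hypothetical helper
#   import_jax_weights_(model, "/path/to/params_model_1.npz", version="model_1")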
# Copyright 2022 DP Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Install script for setuptools."""
from setuptools import find_packages
from setuptools import setup
setup(
name="unifold",
version="2.2.0",
description="An open-source platform for developing protein folding models beyond AlphaFold.",
author="DP Technology",
author_email="unifold@dp.tech",
license="Apache License, Version 2.0",
url="https://github.com/dptech-corp/Uni-Fold",
packages=find_packages(
exclude=["scripts", "tests", "example_data", "docker", "benchmark", "img", "evaluation", "notebooks"]
),
install_requires=[
"absl-py",
"biopython",
"ml-collections",
"numpy",
"pandas",
"scipy",
],
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10087
[ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
[ -z "${update_freq}" ] && update_freq=1
[ -z "${total_step}" ] && total_step=80000
[ -z "${warmup_step}" ] && warmup_step=1000
[ -z "${decay_step}" ] && decay_step=50000
[ -z "${decay_ratio}" ] && decay_ratio=0.95
[ -z "${sd_prob}" ] && sd_prob=0.75
[ -z "${lr}" ] && lr=1e-3
[ -z "${seed}" ] && seed=42
[ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1
[ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
echo "n_gpu per node" $n_gpu
echo "OMPI_COMM_WORLD_SIZE" $OMPI_COMM_WORLD_SIZE
echo "OMPI_COMM_WORLD_RANK" $OMPI_COMM_WORLD_RANK
echo "MASTER_IP" $MASTER_IP
echo "MASTER_PORT" $MASTER_PORT
echo "data" $1
echo "save_dir" $2
echo "decay_step" $decay_step
echo "warmup_step" $warmup_step
echo "decay_ratio" $decay_ratio
echo "lr" $lr
echo "total_step" $total_step
echo "update_freq" $update_freq
echo "seed" $seed
echo "data_folder:"
ls $1
echo "create folder for save"
mkdir -p $2
echo "start training"
model_name=$3
tmp_dir=`mktemp -d`
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP \
$(which unicore-train) $1 --user-dir unifold \
--num-workers 4 --ddp-backend=no_c10d \
--model-name $model_name \
--task af2 --loss af2 --arch af2 --sd-prob $sd_prob \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
--lr-scheduler exponential_decay --lr $lr --warmup-updates $warmup_step --decay-ratio $decay_ratio --decay-steps $decay_step --stair-decay --batch-size 1 \
--update-freq $update_freq --seed $seed --tensorboard-logdir $2/tsb/ \
--max-update $total_step --max-epoch 1 --log-interval 10 --log-format simple \
--save-interval-updates 500 --validate-interval-updates 500 --keep-interval-updates 40 --no-epoch-checkpoints \
--save-dir $2 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --bf16 --ema-decay 0.999 --data-buffer-size 32 --bf16-sr
rm -rf $tmp_dir
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10086
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
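# Example invocation (illustrative; the script name and output directory are placeholders):
#   bash <this_script>.sh ./demo_checkpoints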
mkdir -p $1
tmp_dir=`mktemp -d`
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT $(which unicore-train) ./example_data/ --user-dir unifold \
--num-workers 8 --ddp-backend=no_c10d \
--task af2 --loss af2 --arch af2 \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
--lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 1000 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
--update-freq 1 --seed 42 --tensorboard-logdir $1/tsb/ \
--max-update 1000 --max-epoch 1 --log-interval 10 --log-format simple \
--save-interval-updates 100 --validate-interval-updates 100 --keep-interval-updates 5 --no-epoch-checkpoints \
--save-dir $1 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --bf16 --bf16-sr # for V100 or older GPUs, you can disable --bf16 for faster speed.
rm -rf $tmp_dir