"configs/datasets/siqa/siqa_ppl_7845b0.py" did not exist on "7d346000bb8f1f7611f88dc8e003bdf8c9ae3ece"
Commit a1c29028 authored by zhangqha's avatar zhangqha
Browse files

update uni-fold

{
"PDB_AAHTFN": 1.0
}
{
"8d27_A": [
"8d27_A",
"8d27_B"
]
}
{
"8d27_A": 1.0
}
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10087
[ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
[ -z "${update_freq}" ] && update_freq=1
[ -z "${total_step}" ] && total_step=10000
[ -z "${warmup_step}" ] && warmup_step=500
[ -z "${decay_step}" ] && decay_step=10000
[ -z "${decay_ratio}" ] && decay_ratio=1.0
[ -z "${lr}" ] && lr=5e-4
[ -z "${seed}" ] && seed=31
[ -z "${sd_prob}" ] && sd_prob=0.5
[ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1
[ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
echo "n_gpu per node" $n_gpu
echo "OMPI_COMM_WORLD_SIZE" $OMPI_COMM_WORLD_SIZE
echo "OMPI_COMM_WORLD_RANK" $OMPI_COMM_WORLD_RANK
echo "MASTER_IP" $MASTER_IP
echo "MASTER_PORT" $MASTER_PORT
echo "data" $1
echo "save_dir" $2
echo "decay_step" $decay_step
echo "warmup_step" $warmup_step
echo "decay_ratio" $decay_ratio
echo "lr" $lr
echo "total_step" $total_step
echo "update_freq" $update_freq
echo "seed" $seed
echo "data_folder:"
ls $1
echo "create folder for save"
mkdir -p $2
echo "start training"
OPTION=""
if [ -f "$2/checkpoint_last.pt" ]; then
echo "ckp exists."
else
echo "finetuning from inital training..."
OPTION=" --finetune-from-model $3 --load-from-ema "
fi
model_name=$4
tmp_dir=`mktemp -d`
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP \
$(which unicore-train) $1 --user-dir unifold \
--num-workers 4 --ddp-backend=no_c10d \
--task af2 --loss af2 --arch af2 --sd-prob $sd_prob \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
--lr-scheduler exponential_decay --lr $lr --warmup-updates $warmup_step --decay-ratio $decay_ratio --decay-steps $decay_step --stair-decay --batch-size 1 \
--update-freq $update_freq --seed $seed --tensorboard-logdir $2/tsb/ \
--max-update $total_step --max-epoch 1 --log-interval 10 --log-format simple \
--save-interval-updates 500 --validate-interval-updates 500 --keep-interval-updates 40 --no-epoch-checkpoints \
--save-dir $2 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --bf16 --ema-decay 0.999 --data-buffer-size 32 --bf16-sr --model-name $model_name $OPTION
rm -rf $tmp_dir
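# For reference, a minimal sketch of how the monomer finetuning script above might be
# invoked. The script filename and all paths below are hypothetical; the four positional
# arguments (data dir, save dir, initial checkpoint, model name) follow the script's use
# of $1-$4, and hyperparameters can be overridden through the environment as shown:
n_gpu=8 lr=5e-4 total_step=10000 \
bash train_monomer_ft.sh ./example_data ./outputs/monomer_ft ./checkpoints/monomer.unifold.pt model_2_ft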
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10087
[ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
[ -z "${update_freq}" ] && update_freq=1
[ -z "${total_step}" ] && total_step=10000
[ -z "${warmup_step}" ] && warmup_step=500
[ -z "${decay_step}" ] && decay_step=10000
[ -z "${decay_ratio}" ] && decay_ratio=1.0
[ -z "${sd_prob}" ] && sd_prob=0.5
[ -z "${lr}" ] && lr=5e-4
[ -z "${seed}" ] && seed=31
[ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1
[ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
echo "n_gpu per node" $n_gpu
echo "OMPI_COMM_WORLD_SIZE" $OMPI_COMM_WORLD_SIZE
echo "OMPI_COMM_WORLD_RANK" $OMPI_COMM_WORLD_RANK
echo "MASTER_IP" $MASTER_IP
echo "MASTER_PORT" $MASTER_PORT
echo "data" $1
echo "save_dir" $2
echo "decay_step" $decay_step
echo "warmup_step" $warmup_step
echo "decay_ratio" $decay_ratio
echo "lr" $lr
echo "total_step" $total_step
echo "update_freq" $update_freq
echo "seed" $seed
echo "data_folder:"
ls $1
echo "create folder for save"
mkdir -p $2
echo "start training"
OPTION=""
if [ -f "$2/checkpoint_last.pt" ]; then
echo "ckp exists."
else
echo "finetuning from inital training..."
OPTION=" --finetune-from-model $3 --load-from-ema "
fi
model_name=$4
tmp_dir=`mktemp -d`
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP \
$(which unicore-train) $1 --user-dir unifold \
--num-workers 4 --ddp-backend=no_c10d \
--task af2 --loss afm --arch af2 --sd-prob $sd_prob \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
--lr-scheduler exponential_decay --lr $lr --warmup-updates $warmup_step --decay-ratio $decay_ratio --decay-steps $decay_step --stair-decay --batch-size 1 \
--update-freq $update_freq --seed $seed --tensorboard-logdir $2/tsb/ \
--max-update $total_step --max-epoch 1 --log-interval 10 --log-format simple \
--save-interval-updates 500 --validate-interval-updates 500 --keep-interval-updates 40 --no-epoch-checkpoints \
--save-dir $2 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --bf16 --ema-decay 0.999 --data-buffer-size 32 --bf16-sr --model-name $model_name $OPTION
rm -rf $tmp_dir
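# The multimer variant above differs from the monomer script only in training with
# --loss afm. A hypothetical invocation with the same four positional arguments
# (script name and paths are assumptions):
bash train_multimer_ft.sh ./example_data ./outputs/multimer_ft ./checkpoints/multimer.unifold.pt multimer_ft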
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "jMGcXXPabEN4"
},
"source": [
"# Uni-Fold Notebook\n",
"\n",
"This notebook provides protein structure prediction service of [Uni-Fold](https://github.com/dptech-corp/Uni-Fold/) as well as [UF-Symmetry](https://www.biorxiv.org/content/10.1101/2022.08.30.505833v1). Predictions of both protein monomers and multimers are supported. The homology search process in this notebook is enabled with the [MMSeqs2](https://github.com/soedinglab/MMseqs2.git) server provided by [ColabFold](https://github.com/sokrypton/ColabFold). For more consistent results with the original AlphaFold(-Multimer), please refer to the open-source repository of [Uni-Fold](https://github.com/dptech-corp/Uni-Fold/), or our convenient web server at [Hermite™](https://hermite.dp.tech/).\n",
"\n",
"Please note that this notebook is provided as an early-access prototype, and is NOT an official product of DP Technology. It is provided for theoretical modeling only and caution should be exercised in its use. \n",
"\n",
"**Licenses**\n",
"\n",
"This Colab uses the [Uni-Fold model parameters](https://github.com/dptech-corp/Uni-Fold/#model-parameters-license) and its outputs are under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).\n",
"\n",
"\n",
"**Citations**\n",
"\n",
"Please cite the following papers if you use this notebook:\n",
"\n",
"* Ziyao Li, Xuyang Liu, Weijie Chen, Fan Shen, Hangrui Bi, Guolin Ke, Linfeng Zhang. \"[Uni-Fold: An Open-Source Platform for Developing Protein Folding Models beyond AlphaFold.](https://www.biorxiv.org/content/10.1101/2022.08.04.502811v1)\" biorxiv (2022)\n",
"* Ziyao Li, Shuwen Yang, Xuyang Liu, Weijie Chen, Han Wen, Fan Shen, Guolin Ke, Linfeng Zhang. \"[Uni-Fold Symmetry: Harnessing Symmetry in Folding Large Protein Complexes.](https://www.biorxiv.org/content/10.1101/2022.08.30.505833v1)\" bioRxiv (2022)\n",
"* Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. \"[ColabFold: Making protein folding accessible to all.](https://www.nature.com/articles/s41592-022-01488-1)\" Nature Methods (2022)\n",
"\n",
"**Acknowledgements**\n",
"\n",
"The model architecture of Uni-Fold is largely based on [AlphaFold](https://doi.org/10.1038/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). The design of this notebook refers directly to [ColabFold](https://www.nature.com/articles/s41592-022-01488-1). We specially thank [@sokrypton](https://twitter.com/sokrypton) for his helpful suggestions to this notebook.\n",
"\n",
"Copyright © 2022 DP Technology. All rights reserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "y0Evc150bEN7"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the _Play_ button \n",
"#@markdown on the left to download and import third-party software \n",
"#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/dptech-corp/Uni-Fold/#acknowledgements) in our readme.)\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n",
"%%bash\n",
"if [ ! -f ENV_READY ]; then\n",
" apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y -qq \\\n",
" hmmer \\\n",
" kalign\n",
"\n",
" # Install HHsuite.\n",
" wget -q https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-AVX2-Linux.tar.gz; tar xfz hhsuite-3.3.0-AVX2-Linux.tar.gz; ln -s $(pwd)/bin/* /usr/bin \n",
"\n",
" pip3 -q install py3dmol gdown\n",
"\n",
" touch ENV_READY\n",
"fi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rETqvokYbEN9"
},
"outputs": [],
"source": [
"#@title Download Uni-Fold\n",
"\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"%%bash\n",
"GIT_REPO='https://github.com/dptech-corp/Uni-Fold'\n",
"UNICORE_URL='https://github.com/dptech-corp/Uni-Core/releases/download/0.0.1/unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl'\n",
"PARAM_URL='https://drive.google.com/uc?id=1A9iXMYCwP0f_U0FgISJ_6BX7FXZtglvV'\n",
"UF_SYMM_PARAM_URL='https://drive.google.com/uc?id=1UNEGzmueQTxY05QIRweKHxOjr1ht-G_Q'\n",
"\n",
"if [ ! -f UNIFOLD_READY ]; then\n",
" wget ${UNICORE_URL} \n",
" pip3 -q install \"unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl\"\n",
" git clone -b main ${GIT_REPO}\n",
" pip3 -q install ./Uni-Fold\n",
" gdown ${PARAM_URL}\n",
" tar -xzf \"unifold_params_2022-08-01.tar.gz\"\n",
" gdown ${UF_SYMM_PARAM_URL}\n",
" tar -xzf \"uf_symmetry_params_2022-09-06.tar.gz\"\n",
"\n",
" touch UNIFOLD_READY\n",
"fi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "j-xTD0QubEN-"
},
"outputs": [],
"source": [
"#@title Input protein sequence(s), then hit `Runtime` -> `Run all`\n",
"import os\n",
"import re\n",
"import hashlib\n",
"import random\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from typing import Dict, List, Sequence, Tuple, Union, Any, Optional\n",
"\n",
"from unifold.data import residue_constants, protein\n",
"from unifold.msa.utils import divide_multi_chains\n",
"\n",
"MIN_SINGLE_SEQUENCE_LENGTH = 16\n",
"MAX_SINGLE_SEQUENCE_LENGTH = 1000\n",
"MAX_MULTIMER_LENGTH = 1000\n",
"\n",
"output_dir_base = \"./prediction\"\n",
"os.makedirs(output_dir_base, exist_ok=True)\n",
"\n",
"def clean_and_validate_sequence(\n",
" input_sequence: str, min_length: int, max_length: int) -> str:\n",
" \"\"\"Checks that the input sequence is ok and returns a clean version of it.\"\"\"\n",
" # Remove all whitespaces, tabs and end lines; upper-case.\n",
" clean_sequence = input_sequence.translate(\n",
" str.maketrans('', '', ' \\n\\t')).upper()\n",
" aatypes = set(residue_constants.restypes) # 20 standard aatypes.\n",
" if not set(clean_sequence).issubset(aatypes):\n",
" raise ValueError(\n",
" f'Input sequence contains non-amino acid letters: '\n",
" f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '\n",
" 'amino acids as inputs.')\n",
" if len(clean_sequence) < min_length:\n",
" raise ValueError(\n",
" f'Input sequence is too short: {len(clean_sequence)} amino acids, '\n",
" f'while the minimum is {min_length}')\n",
" if len(clean_sequence) > max_length:\n",
" raise ValueError(\n",
" f'Input sequence is too long: {len(clean_sequence)} amino acids, while '\n",
" f'the maximum is {max_length}. You may be able to run it with the full '\n",
" f'Uni-Fold system depending on your resources (system memory, '\n",
" f'GPU memory).')\n",
" return clean_sequence\n",
"\n",
"\n",
"def validate_input(\n",
" input_sequences: Sequence[str],\n",
" symmetry_group: str,\n",
" min_length: int,\n",
" max_length: int,\n",
" max_multimer_length: int) -> Tuple[Sequence[str], bool]:\n",
" \"\"\"Validates and cleans input sequences and determines which model to use.\"\"\"\n",
" sequences = []\n",
"\n",
" for input_sequence in input_sequences:\n",
" if input_sequence.strip():\n",
" input_sequence = clean_and_validate_sequence(\n",
" input_sequence=input_sequence,\n",
" min_length=min_length,\n",
" max_length=max_length)\n",
" sequences.append(input_sequence)\n",
" \n",
" if symmetry_group != 'C1':\n",
" if symmetry_group.startswith('C') and symmetry_group[1:].isnumeric():\n",
" print(f'Using UF-Symmetry with group {symmetry_group}. If you do not '\n",
" f'want to use UF-Symmetry, please use `C1` and copy the AU '\n",
" f'sequences to the count in the assembly.')\n",
" is_multimer = (len(sequences) > 1)\n",
" return sequences, is_multimer, symmetry_group\n",
" else:\n",
" raise ValueError(f\"UF-Symmetry does not support symmetry group \"\n",
" f\"{symmetry_group} currently. Cyclic groups (Cx) are \"\n",
" f\"supported only.\")\n",
"\n",
" elif len(sequences) == 1:\n",
" print('Using the single-chain model.')\n",
" return sequences, False, None\n",
"\n",
" elif len(sequences) > 1:\n",
" total_multimer_length = sum([len(seq) for seq in sequences])\n",
" if total_multimer_length > max_multimer_length:\n",
" raise ValueError(f'The total length of multimer sequences is too long: '\n",
" f'{total_multimer_length}, while the maximum is '\n",
" f'{max_multimer_length}. Please use the full AlphaFold '\n",
" f'system for long multimers.')\n",
" print(f'Using the multimer model with {len(sequences)} sequences.')\n",
" return sequences, True, None\n",
"\n",
" else:\n",
" raise ValueError('No input amino acid sequence provided, please provide at '\n",
" 'least one sequence.')\n",
"\n",
"def add_hash(x,y):\n",
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"jobname = 'unifold_colab' #@param {type:\"string\"}\n",
"\n",
"sequence_1 = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI' #@param {type:\"string\"}\n",
"sequence_2 = '' #@param {type:\"string\"}\n",
"sequence_3 = '' #@param {type:\"string\"}\n",
"sequence_4 = '' #@param {type:\"string\"}\n",
"\n",
"#@markdown Use symmetry group `C1` for default Uni-Fold predictions.\n",
"#@markdown Or, specify a **cyclic** symmetry group (e.g. `C4``) and\n",
"#@markdown the sequences of the asymmetric unit (i.e. **do not copy\n",
"#@markdown them multiple times**) to predict with UF-Symmetry.\n",
"\n",
"symmetry_group = 'C1' #@param {type:\"string\"}\n",
"\n",
"use_templates = True #@param {type:\"boolean\"}\n",
"msa_mode = \"MMseqs2\" #@param [\"MMseqs2\",\"single_sequence\"]\n",
"\n",
"input_sequences = [sequence_1, sequence_2, sequence_3, sequence_4]\n",
"\n",
"basejobname = \"\".join(input_sequences)\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"target_id = add_hash(jobname, basejobname)\n",
"\n",
"# Validate the input.\n",
"sequences, is_multimer, symmetry_group = validate_input(\n",
" input_sequences=input_sequences,\n",
" symmetry_group=symmetry_group,\n",
" min_length=MIN_SINGLE_SEQUENCE_LENGTH,\n",
" max_length=MAX_SINGLE_SEQUENCE_LENGTH,\n",
" max_multimer_length=MAX_MULTIMER_LENGTH)\n",
"\n",
"descriptions = ['> '+target_id+' seq'+str(ii) for ii in range(len(sequences))]\n",
"\n",
"if is_multimer:\n",
" divide_multi_chains(target_id, output_dir_base, sequences, descriptions)\n",
" \n",
"s = []\n",
"for des, seq in zip(descriptions, sequences):\n",
" s += [des, seq]\n",
"\n",
"unique_sequences = []\n",
"[unique_sequences.append(x) for x in sequences if x not in unique_sequences]\n",
"\n",
"if len(unique_sequences)==1:\n",
" homooligomers_num = len(sequences)\n",
"else:\n",
" homooligomers_num = 1\n",
" \n",
"with open(f\"{jobname}.fasta\", \"w\") as f:\n",
" f.write(\"\\n\".join(s))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QThPtPvlbEN_"
},
"outputs": [],
"source": [
"#@title Generate homogeneous features via ColabFold-MMSeqs2 server\n",
"#@markdown Acknowledge to [ColabFold](https://github.com/sokrypton/ColabFold.git)\n",
"\n",
"import tarfile\n",
"import requests\n",
"from tqdm import tqdm\n",
"import time\n",
"import logging\n",
"\n",
"from unifold.msa import templates, pipeline\n",
"from unifold.msa.tools import hhsearch\n",
"\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"DEFAULT_API_SERVER = \"https://api.colabfold.com\"\n",
"\n",
"def run_mmseqs2(x, prefix, use_env=True, \n",
" use_templates=False, use_pairing=False,\n",
" host_url=\"https://api.colabfold.com\") -> Tuple[List[str], List[str]]:\n",
" submission_endpoint = \"ticket/pair\" if use_pairing else \"ticket/msa\"\n",
"\n",
" def submit(seqs, mode, N=101):\n",
" n, query = N, \"\"\n",
" for seq in seqs:\n",
" query += f\">{n}\\n{seq}\\n\"\n",
" n += 1\n",
"\n",
" res = requests.post(f'{host_url}/{submission_endpoint}', data={'q':query,'mode': mode})\n",
" try:\n",
" out = res.json()\n",
" except ValueError:\n",
" logger.error(f\"Server didn't reply with json: {res.text}\")\n",
" out = {\"status\":\"ERROR\"}\n",
" return out\n",
"\n",
" def status(ID):\n",
" res = requests.get(f'{host_url}/ticket/{ID}')\n",
" try:\n",
" out = res.json()\n",
" except ValueError:\n",
" logger.error(f\"Server didn't reply with json: {res.text}\")\n",
" out = {\"status\":\"ERROR\"}\n",
" return out\n",
"\n",
" def download(ID, path):\n",
" res = requests.get(f'{host_url}/result/download/{ID}')\n",
" with open(path,\"wb\") as out: out.write(res.content)\n",
"\n",
" # process input x\n",
" seqs = [x] if isinstance(x, str) else x\n",
"\n",
" mode = \"env\"\n",
" if use_pairing:\n",
" mode = \"\"\n",
" use_templates = False\n",
" use_env = False\n",
"\n",
" # define path\n",
" path = f\"{prefix}\"\n",
" if not os.path.isdir(path): os.mkdir(path)\n",
"\n",
" # call mmseqs2 api\n",
" tar_gz_file = f'{path}/out_{mode}.tar.gz'\n",
" N,REDO = 101,True\n",
"\n",
" # deduplicate and keep track of order\n",
" seqs_unique = []\n",
" #TODO this might be slow for large sets\n",
" [seqs_unique.append(x) for x in seqs if x not in seqs_unique]\n",
" Ms = [N + seqs_unique.index(seq) for seq in seqs]\n",
" # lets do it!\n",
" if not os.path.isfile(tar_gz_file):\n",
" TIME_ESTIMATE = 150 * len(seqs_unique)\n",
" with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" while REDO:\n",
" pbar.set_description(\"SUBMIT\")\n",
"\n",
" # Resubmit job until it goes through\n",
" out = submit(seqs_unique, mode, N)\n",
" while out[\"status\"] in [\"UNKNOWN\", \"RATELIMIT\"]:\n",
" sleep_time = 5 + random.randint(0, 5)\n",
" logger.error(f\"Sleeping for {sleep_time}s. Reason: {out['status']}\")\n",
" # resubmit\n",
" time.sleep(sleep_time)\n",
" out = submit(seqs_unique, mode, N)\n",
"\n",
" if out[\"status\"] == \"ERROR\":\n",
" raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')\n",
"\n",
" if out[\"status\"] == \"MAINTENANCE\":\n",
" raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')\n",
"\n",
" # wait for job to finish\n",
" ID,TIME = out[\"id\"],0\n",
" pbar.set_description(out[\"status\"])\n",
" while out[\"status\"] in [\"UNKNOWN\",\"RUNNING\",\"PENDING\"]:\n",
" t = 5 + random.randint(0,5)\n",
" logger.error(f\"Sleeping for {t}s. Reason: {out['status']}\")\n",
" time.sleep(t)\n",
" out = status(ID)\n",
" pbar.set_description(out[\"status\"])\n",
" if out[\"status\"] == \"RUNNING\":\n",
" TIME += t\n",
" pbar.update(n=t)\n",
"\n",
" if out[\"status\"] == \"COMPLETE\":\n",
" if TIME < TIME_ESTIMATE:\n",
" pbar.update(n=(TIME_ESTIMATE-TIME))\n",
" REDO = False\n",
"\n",
" if out[\"status\"] == \"ERROR\":\n",
" REDO = False\n",
" raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')\n",
"\n",
" # Download results\n",
" download(ID, tar_gz_file)\n",
"\n",
" # prep list of a3m files\n",
" if use_pairing:\n",
" a3m_files = [f\"{path}/pair.a3m\"]\n",
" else:\n",
" a3m_files = [f\"{path}/uniref.a3m\"]\n",
" if use_env: a3m_files.append(f\"{path}/bfd.mgnify30.metaeuk30.smag30.a3m\")\n",
"\n",
" # extract a3m files\n",
" if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):\n",
" with tarfile.open(tar_gz_file) as tar_gz:\n",
" tar_gz.extractall(path)\n",
"\n",
" # templates\n",
" if use_templates:\n",
" templates = {}\n",
"\n",
" for line in open(f\"{path}/pdb70.m8\",\"r\"):\n",
" p = line.rstrip().split()\n",
" M,pdb,qid,e_value = p[0],p[1],p[2],p[10]\n",
" M = int(M)\n",
" if M not in templates: templates[M] = []\n",
" templates[M].append(pdb)\n",
"\n",
" template_paths = {}\n",
" for k,TMPL in templates.items():\n",
" TMPL_PATH = f\"{prefix}/templates_{k}\"\n",
" if not os.path.isdir(TMPL_PATH):\n",
" os.mkdir(TMPL_PATH)\n",
" TMPL_LINE = \",\".join(TMPL[:20])\n",
" os.system(f\"curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/\")\n",
" os.system(f\"cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex\")\n",
" os.system(f\"touch {TMPL_PATH}/pdb70_cs219.ffdata\")\n",
" template_paths[k] = TMPL_PATH\n",
"\n",
" # gather a3m lines\n",
" a3m_lines = {}\n",
" for a3m_file in a3m_files:\n",
" update_M,M = True,None\n",
" for line in open(a3m_file,\"r\"):\n",
" if len(line) > 0:\n",
" if \"\\x00\" in line:\n",
" line = line.replace(\"\\x00\",\"\")\n",
" update_M = True\n",
" if line.startswith(\">\") and update_M:\n",
" M = int(line[1:].rstrip())\n",
" update_M = False\n",
" if M not in a3m_lines: a3m_lines[M] = []\n",
" a3m_lines[M].append(line)\n",
"\n",
" # return results\n",
"\n",
" a3m_lines = [\"\".join(a3m_lines[n]) for n in Ms]\n",
"\n",
" if use_templates:\n",
" template_paths_ = []\n",
" for n in Ms:\n",
" if n not in template_paths:\n",
" template_paths_.append(None)\n",
" #print(f\"{n-N}\\tno_templates_found\")\n",
" else:\n",
" template_paths_.append(template_paths[n])\n",
" template_paths = template_paths_\n",
"\n",
"\n",
" return (a3m_lines, template_paths) if use_templates else a3m_lines\n",
"\n",
"def get_null_template(\n",
" query_sequence: Union[List[str], str], num_temp: int = 1\n",
") -> Dict[str, Any]:\n",
" ln = (\n",
" len(query_sequence)\n",
" if isinstance(query_sequence, str)\n",
" else sum(len(s) for s in query_sequence)\n",
" )\n",
" output_templates_sequence = \"A\" * ln\n",
" output_confidence_scores = np.full(ln, 1.0)\n",
"\n",
" templates_all_atom_positions = np.zeros(\n",
" (ln, templates.residue_constants.atom_type_num, 3)\n",
" )\n",
" templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num))\n",
" templates_aatype = templates.residue_constants.sequence_to_onehot(\n",
" output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID\n",
" )\n",
" template_features = {\n",
" \"template_all_atom_positions\": np.tile(\n",
" templates_all_atom_positions[None], [num_temp, 1, 1, 1]\n",
" ),\n",
" \"template_all_atom_masks\": np.tile(\n",
" templates_all_atom_masks[None], [num_temp, 1, 1]\n",
" ),\n",
" \"template_sequence\": [f\"none\".encode()] * num_temp,\n",
" \"template_aatype\": np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),\n",
" \"template_domain_names\": [f\"none\".encode()] * num_temp,\n",
" \"template_sum_probs\": np.zeros([num_temp], dtype=np.float32),\n",
" }\n",
" return template_features\n",
"\n",
"\n",
"def get_template(\n",
" a3m_lines: str, template_path: str, query_sequence: str\n",
") -> Dict[str, Any]:\n",
" template_featurizer = templates.HhsearchHitFeaturizer(\n",
" mmcif_dir=template_path,\n",
" max_template_date=\"2100-01-01\",\n",
" max_hits=20,\n",
" kalign_binary_path=\"kalign\",\n",
" release_dates_path=None,\n",
" obsolete_pdbs_path=None,\n",
" )\n",
"\n",
" hhsearch_pdb70_runner = hhsearch.HHSearch(\n",
" binary_path=\"hhsearch\", databases=[f\"{template_path}/pdb70\"]\n",
" )\n",
"\n",
" hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)\n",
" hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)\n",
" templates_result = template_featurizer.get_templates(\n",
" query_sequence=query_sequence, hits=hhsearch_hits\n",
" )\n",
" return dict(templates_result.features)\n",
" \n",
"def get_msa_and_templates(\n",
" jobname: str,\n",
" query_seqs_unique: Union[str, List[str]],\n",
" result_dir: Path,\n",
" msa_mode: str,\n",
" use_templates: bool,\n",
" homooligomers_num: int = 1,\n",
" host_url: str = DEFAULT_API_SERVER,\n",
") -> Tuple[\n",
" Optional[List[str]], Optional[List[str]], List[str], List[int], List[Dict[str, Any]]\n",
"]:\n",
" \n",
" use_env = msa_mode == \"MMseqs2\"\n",
"\n",
" template_features = []\n",
" if use_templates:\n",
" a3m_lines_mmseqs2, template_paths = run_mmseqs2(\n",
" query_seqs_unique,\n",
" str(result_dir.joinpath(jobname)),\n",
" use_env,\n",
" use_templates=True,\n",
" host_url=host_url,\n",
" )\n",
" if template_paths is None:\n",
" logger.info(\"No template detected\")\n",
" for index in range(0, len(query_seqs_unique)):\n",
" template_feature = get_null_template(query_seqs_unique[index])\n",
" template_features.append(template_feature)\n",
" else:\n",
" for index in range(0, len(query_seqs_unique)):\n",
" if template_paths[index] is not None:\n",
" template_feature = get_template(\n",
" a3m_lines_mmseqs2[index],\n",
" template_paths[index],\n",
" query_seqs_unique[index],\n",
" )\n",
" if len(template_feature[\"template_domain_names\"]) == 0:\n",
" template_feature = get_null_template(query_seqs_unique[index])\n",
" logger.info(f\"Sequence {index} found no templates\")\n",
" else:\n",
" logger.info(\n",
" f\"Sequence {index} found templates: {template_feature['template_domain_names'].astype(str).tolist()}\"\n",
" )\n",
" else:\n",
" template_feature = get_null_template(query_seqs_unique[index])\n",
" logger.info(f\"Sequence {index} found no templates\")\n",
"\n",
" template_features.append(template_feature)\n",
" else:\n",
" for index in range(0, len(query_seqs_unique)):\n",
" template_feature = get_null_template(query_seqs_unique[index])\n",
" template_features.append(template_feature)\n",
"\n",
"\n",
" if msa_mode == \"single_sequence\":\n",
" a3m_lines = []\n",
" num = 101\n",
" for i, seq in enumerate(query_seqs_unique):\n",
" a3m_lines.append(\">\" + str(num + i) + \"\\n\" + seq)\n",
" else:\n",
" # find normal a3ms\n",
" a3m_lines = run_mmseqs2(\n",
" query_seqs_unique,\n",
" str(result_dir.joinpath(jobname)),\n",
" use_env,\n",
" use_pairing=False,\n",
" host_url=host_url,\n",
" )\n",
" if len(query_seqs_unique)>1:\n",
" # find paired a3m if not a homooligomers\n",
" paired_a3m_lines = run_mmseqs2(\n",
" query_seqs_unique,\n",
" str(result_dir.joinpath(jobname)),\n",
" use_env,\n",
" use_pairing=True,\n",
" host_url=host_url,\n",
" )\n",
" else:\n",
" num = 101\n",
" paired_a3m_lines = []\n",
" for i in range(0, homooligomers_num):\n",
" paired_a3m_lines.append(\n",
" \">\" + str(num + i) + \"\\n\" + query_seqs_unique[0] + \"\\n\"\n",
" )\n",
"\n",
" return (\n",
" a3m_lines,\n",
" paired_a3m_lines,\n",
" template_features,\n",
" )\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RWwTgjo4bEOB"
},
"outputs": [],
"source": [
"#@title Process features for Uni-Fold prediction\n",
"import pickle\n",
"import gzip\n",
"from unifold.msa import parsers\n",
"from unifold.msa import pipeline\n",
"from unifold.data.utils import compress_features\n",
"from unifold.data.protein import PDB_CHAIN_IDS\n",
"\n",
"result_dir = Path(output_dir_base)\n",
"output_dir = os.path.join(output_dir_base, target_id)\n",
"\n",
"(\n",
" unpaired_msa,\n",
" paired_msa,\n",
" template_results,\n",
") = get_msa_and_templates(\n",
" target_id,\n",
" unique_sequences,\n",
" result_dir=result_dir,\n",
" msa_mode=msa_mode,\n",
" use_templates=use_templates,\n",
" homooligomers_num = homooligomers_num\n",
")\n",
"\n",
"\n",
"for idx, seq in enumerate(unique_sequences):\n",
" chain_id = PDB_CHAIN_IDS[idx]\n",
" sequence_features = pipeline.make_sequence_features(\n",
" sequence=seq, description=f'> {jobname} seq {chain_id}', num_res=len(seq)\n",
" )\n",
" monomer_msa = parsers.parse_a3m(unpaired_msa[idx])\n",
" msa_features = pipeline.make_msa_features([monomer_msa])\n",
" template_features = template_results[idx]\n",
" feature_dict = {**sequence_features, **msa_features, **template_features}\n",
" feature_dict = compress_features(feature_dict)\n",
" features_output_path = os.path.join(\n",
" output_dir, \"{}.feature.pkl.gz\".format(chain_id)\n",
" )\n",
" pickle.dump(\n",
" feature_dict, \n",
" gzip.GzipFile(features_output_path, \"wb\"), \n",
" protocol=4\n",
" )\n",
" if is_multimer:\n",
" multimer_msa = parsers.parse_a3m(paired_msa[idx])\n",
" pair_features = pipeline.make_msa_features([multimer_msa])\n",
" pair_feature_dict = compress_features(pair_features)\n",
" uniprot_output_path = os.path.join(\n",
" output_dir, \"{}.uniprot.pkl.gz\".format(chain_id)\n",
" )\n",
" pickle.dump(\n",
" pair_feature_dict,\n",
" gzip.GzipFile(uniprot_output_path, \"wb\"),\n",
" protocol=4,\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RJUxaO7Ofw1L"
},
"outputs": [],
"source": [
"#@title Uni-Fold prediction\n",
"\n",
"from unittest import result\n",
"import torch\n",
"import json\n",
"from unifold.config import model_config\n",
"from unifold.modules.alphafold import AlphaFold\n",
"from unifold.dataset import load_and_process, UnifoldDataset\n",
"from unicore.utils import (\n",
" tensor_tree_map,\n",
")\n",
"from unifold.symmetry import (\n",
" UFSymmetry,\n",
" uf_symmetry_config,\n",
" assembly_from_prediction,\n",
" load_and_process_symmetry,\n",
")\n",
"\n",
"def automatic_chunk_size(seq_len):\n",
" if seq_len < 512:\n",
" chunk_size = 256\n",
" elif seq_len < 1024:\n",
" chunk_size = 128\n",
" elif seq_len < 2048:\n",
" chunk_size = 32\n",
" elif seq_len < 3072:\n",
" chunk_size = 16\n",
" else:\n",
" chunk_size = 1\n",
" return chunk_size\n",
"\n",
"\n",
"def load_feature_for_one_target(\n",
" config, data_folder, seed=0, is_multimer=False, use_uniprot=False, symmetry_group=None,\n",
"):\n",
" if not is_multimer:\n",
" uniprot_msa_dir = None\n",
" sequence_ids = [\"A\"]\n",
" if use_uniprot:\n",
" uniprot_msa_dir = data_folder\n",
"\n",
" else:\n",
" uniprot_msa_dir = data_folder\n",
" sequence_ids = open(os.path.join(data_folder, \"chains.txt\")).readline().split()\n",
" \n",
" if symmetry_group is None:\n",
" batch, _ = load_and_process(\n",
" config=config.data,\n",
" mode=\"predict\",\n",
" seed=seed,\n",
" batch_idx=None,\n",
" data_idx=0,\n",
" is_distillation=False,\n",
" sequence_ids=sequence_ids,\n",
" monomer_feature_dir=data_folder,\n",
" uniprot_msa_dir=uniprot_msa_dir,\n",
" )\n",
" \n",
" else:\n",
" batch, _ = load_and_process_symmetry(\n",
" config=config.data,\n",
" mode=\"predict\",\n",
" seed=seed,\n",
" batch_idx=None,\n",
" data_idx=0,\n",
" is_distillation=False,\n",
" symmetry=symmetry_group,\n",
" sequence_ids=sequence_ids,\n",
" monomer_feature_dir=data_folder,\n",
" uniprot_msa_dir=uniprot_msa_dir,\n",
" )\n",
" batch = UnifoldDataset.collater([batch])\n",
" return batch\n",
"\n",
"if symmetry_group is not None:\n",
" model_name = \"uf_symmetry\"\n",
" param_path = \"./uf_symmetry.pt\"\n",
"elif is_multimer:\n",
" model_name = \"multimer_ft\"\n",
" param_path = \"./multimer.unifold.pt\"\n",
"else:\n",
" model_name = \"model_2_ft\"\n",
" param_path = \"./monomer.unifold.pt\"\n",
"\n",
"max_recycling_iters = 3 #@param {type:\"integer\"}\n",
"num_ensembles = 2 #@param {type:\"integer\"}\n",
"manual_seed = 42 #@param {type:\"integer\"}\n",
"times = 3 #@param {type:\"integer\"}\n",
"\n",
"if symmetry_group is None:\n",
" config = model_config(model_name)\n",
"else:\n",
" config = uf_symmetry_config()\n",
"config.data.common.max_recycling_iters = max_recycling_iters\n",
"config.globals.max_recycling_iters = max_recycling_iters\n",
"config.data.predict.num_ensembles = num_ensembles\n",
"\n",
"# faster prediction with large chunk\n",
"config.globals.chunk_size = 128\n",
"model = AlphaFold(config) if symmetry_group is None else UFSymmetry(config)\n",
"print(\"start to load params {}\".format(param_path))\n",
"state_dict = torch.load(param_path)[\"ema\"][\"params\"]\n",
"state_dict = {\".\".join(k.split(\".\")[1:]): v for k, v in state_dict.items()}\n",
"model.load_state_dict(state_dict)\n",
"model = model.to(\"cuda:0\")\n",
"model.eval()\n",
"model.inference_mode()\n",
"\n",
"# data path is based on target_name\n",
"cur_param_path_postfix = os.path.split(param_path)[-1]\n",
"\n",
"print(\"start to predict {}\".format(target_id))\n",
"plddts = {}\n",
"ptms = {}\n",
"best_protein = None\n",
"best_score = 0\n",
"best_plddt = None\n",
"best_pae = None\n",
"\n",
"for seed in range(times):\n",
" cur_seed = hash((manual_seed, seed)) % 100000\n",
" batch = load_feature_for_one_target(\n",
" config,\n",
" output_dir,\n",
" cur_seed,\n",
" is_multimer=is_multimer,\n",
" use_uniprot=is_multimer,\n",
" symmetry_group=symmetry_group,\n",
" )\n",
" seq_len = batch[\"aatype\"].shape[-1]\n",
" model.globals.chunk_size = automatic_chunk_size(seq_len)\n",
"\n",
" with torch.no_grad():\n",
" batch = {\n",
" k: torch.as_tensor(v, device=\"cuda:0\")\n",
" for k, v in batch.items()\n",
" }\n",
" shapes = {k: v.shape for k, v in batch.items()}\n",
" print(shapes)\n",
" t = time.perf_counter()\n",
" out = model(batch)\n",
" print(f\"Inference time: {time.perf_counter() - t}\")\n",
"\n",
" def to_float(x):\n",
" if x.dtype == torch.bfloat16 or x.dtype == torch.half:\n",
" return x.float()\n",
" else:\n",
" return x\n",
"\n",
" # Toss out the recycling dimensions --- we don't need them anymore\n",
" batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)\n",
" batch = tensor_tree_map(to_float, batch)\n",
" out = tensor_tree_map(lambda t: t[0, ...], out)\n",
" out = tensor_tree_map(to_float, out)\n",
" batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)\n",
" out = tensor_tree_map(lambda x: np.array(x.cpu()), out)\n",
"\n",
" plddt = out[\"plddt\"]\n",
" mean_plddt = np.mean(plddt)\n",
" plddt_b_factors = np.repeat(\n",
" plddt[..., None], residue_constants.atom_type_num, axis=-1\n",
" )\n",
" # TODO: , may need to reorder chains, based on entity_ids\n",
" if symmetry_group is None:\n",
" cur_protein = protein.from_prediction(\n",
" features=batch, result=out, b_factors=plddt_b_factors\n",
" )\n",
" else:\n",
" plddt_b_factors_assembly = np.concatenate(\n",
" [plddt_b_factors for _ in range(batch[\"symmetry_opers\"].shape[0])])\n",
" cur_protein = assembly_from_prediction(\n",
" result=out, b_factors=plddt_b_factors_assembly,\n",
" )\n",
" cur_save_name = (\n",
" f\"{cur_param_path_postfix}_{cur_seed}\"\n",
" )\n",
" plddts[cur_save_name] = str(mean_plddt)\n",
" if is_multimer and symmetry_group is None:\n",
" ptms[cur_save_name] = str(np.mean(out[\"iptm+ptm\"]))\n",
" with open(os.path.join(output_dir, cur_save_name + '.pdb'), \"w\") as f:\n",
" f.write(protein.to_pdb(cur_protein))\n",
"\n",
" if is_multimer and symmetry_group is None:\n",
" mean_ptm = np.mean(out[\"iptm+ptm\"])\n",
" if mean_ptm>best_score:\n",
" best_protein = cur_protein\n",
" best_pae = out[\"predicted_aligned_error\"]\n",
" best_plddt = out[\"plddt\"]\n",
" best_score = mean_ptm\n",
" else:\n",
" if mean_plddt>best_score:\n",
" best_protein = cur_protein\n",
" best_plddt = out[\"plddt\"]\n",
" best_score = mean_plddt\n",
"\n",
"print(\"plddts\", plddts)\n",
"score_name = f\"{model_name}_{cur_param_path_postfix}\"\n",
"plddt_fname = score_name + \"_plddt.json\"\n",
"json.dump(plddts, open(os.path.join(output_dir, plddt_fname), \"w\"), indent=4)\n",
"if ptms:\n",
" print(\"ptms\", ptms)\n",
" ptm_fname = score_name + \"_ptm.json\"\n",
" json.dump(ptms, open(os.path.join(output_dir, ptm_fname), \"w\"), indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "kryWdmg0jZwT"
},
"outputs": [],
"source": [
"#@title Show the protein structure\n",
"\n",
"# Construct multiclass b-factors to indicate confidence bands\n",
"# 0=very low, 1=low, 2=confident, 3=very high\n",
"# Color bands for visualizing plddt\n",
"import py3Dmol\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import LinearSegmentedColormap\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n",
"\n",
"\n",
"show_sidechains = False #@param {type:\"boolean\"}\n",
"dpi = 100 #@param {type:\"integer\"}\n",
"\n",
"to_visualize_pdb = protein.to_pdb(best_protein)\n",
"\n",
"PLDDT_BANDS = [(0., 0.50, '#FF7D45'),\n",
" (0.50, 0.70, '#FFDB13'),\n",
" (0.70, 0.90, '#65CBF3'),\n",
" (0.90, 1.00, '#0053D6')]\n",
"\n",
"\n",
"# --- Visualise the prediction & confidence ---\n",
"def plot_plddt_legend():\n",
" \"\"\"Plots the legend for pLDDT.\"\"\"\n",
" thresh = ['Very low (pLDDT < 50)',\n",
" 'Low (70 > pLDDT > 50)',\n",
" 'Confident (90 > pLDDT > 70)',\n",
" 'Very high (pLDDT > 90)']\n",
"\n",
" colors = [x[2] for x in PLDDT_BANDS]\n",
"\n",
" plt.figure(figsize=(2, 2))\n",
" for c in colors:\n",
" plt.bar(0, 0, color=c)\n",
" plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" ax = plt.gca()\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"\n",
"if is_multimer and symmetry_group is None:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n",
"view = py3Dmol.view(width=800, height=600)\n",
"view.addModelsAsFrames(to_visualize_pdb)\n",
"style = {'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}}\n",
"if show_sidechains:\n",
" style['stick'] = {}\n",
"view.setStyle({'model':-1}, style)\n",
"view.zoomTo()\n",
"\n",
"grid = GridspecLayout(1, 2)\n",
"out = Output()\n",
"with out:\n",
" view.show()\n",
"grid[0, 0] = out\n",
"\n",
"out = Output()\n",
"with out:\n",
" plot_plddt_legend().show()\n",
"grid[0, 1] = out\n",
"\n",
"display.display(grid)\n",
"\n",
"# Display pLDDT and predicted aligned error (if output by the model).\n",
"if is_multimer and symmetry_group is None:\n",
" num_plots = 2\n",
"else:\n",
" num_plots = 1\n",
"\n",
"plt.figure(figsize=[8 * num_plots , 6])\n",
"plt.subplot(1, num_plots, 1)\n",
"plt.plot(plddt*100)\n",
"plt.title('Predicted LDDT')\n",
"plt.xlabel('Residue')\n",
"plt.ylabel('pLDDT')\n",
"plt.grid()\n",
"plddt_svg_path = os.path.join(output_dir, 'plddt.svg')\n",
"plt.savefig(plddt_svg_path, dpi=dpi, bbox_inches='tight')\n",
"\n",
"\n",
"if num_plots == 2:\n",
" plt.subplot(1, 2, 2)\n",
" max_pae = np.max(best_pae)\n",
" colors = ['#0F006F','#245AE6','#55CCFF','#FFFFFF']\n",
"\n",
" cmap = LinearSegmentedColormap.from_list('mymap', colors)\n",
" im = plt.imshow(best_pae, vmin=0., vmax=max_pae, cmap=cmap)\n",
" plt.colorbar(im, fraction=0.046, pad=0.04)\n",
"\n",
" # Display lines at chain boundaries.\n",
" total_num_res = best_protein.residue_index.shape[-1]\n",
" chain_ids = best_protein.chain_index\n",
" for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
" if chain_boundary.size:\n",
" plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
" plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
"\n",
" plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n",
" pae_svg_path = os.path.join(output_dir, 'pae.svg')\n",
" plt.savefig(pae_svg_path, dpi=dpi, bbox_inches='tight')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Download the prediction\n",
"#@markdown **The content of zip file**:\n",
"#@markdown 1. PDB formatted structures\n",
"#@markdown 2. Json file of the model quality (pLDDT and pTM for multimer)\n",
"#@markdown 2. Plots of the model quality (pLDDT and PAE for multimer)\n",
"\n",
"from google.colab import files\n",
"\n",
"\n",
"plddt_file = os.path.join(output_dir, plddt_fname)\n",
"\n",
"pdb_files = [os.path.join(output_dir, pdb_name + '.pdb') for pdb_name in plddts]\n",
"file_lists = pdb_files + [\n",
" plddt_file, plddt_svg_path\n",
"]\n",
"if is_multimer and symmetry_group is None:\n",
" ptm_file = os.path.join(output_dir, ptm_fname)\n",
" file_lists.append(ptm_file)\n",
" file_lists.append(pae_svg_path)\n",
"\n",
"!zip -q {target_id}.zip {\" \".join(file_lists)}\n",
"files.download(f'{target_id}.zip')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "unifold.ipynb",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.8.10 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
jackhmmer,
hhblits,
hhsearch,
hmmsearch,
hmmbuild,
kalign,
Starting prediction...
/usr/local/lib/python3.7/site-packages/torch/jit/annotations.py:289: UserWarning: TorchScript will treat type annotations of Tensor dtype-specific subtypes as if they are normal Tensors. dtype constraints are not enforced in compilation either.
warnings.warn("TorchScript will treat type annotations of Tensor "
start to load params /root/Uni-Fold-main/Alphafold/monomer.unifold.pt
start to predict T1024
{'aatype': torch.Size([1, 1, 408]), 'residue_index': torch.Size([1, 1, 408]), 'seq_length': torch.Size([1, 1]), 'template_aatype': torch.Size([1, 1, 4, 408]), 'template_all_atom_mask': torch.Size([1, 1, 4, 408, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 408, 37, 3]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([8, 1]), 'seq_mask': torch.Size([1, 1, 408]), 'msa_mask': torch.Size([8, 1, 508, 408]), 'msa_row_mask': torch.Size([8, 1, 508]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 408, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 408]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 408, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 408, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 408, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 408, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 408, 37]), 'atom14_atom_exists': torch.Size([1, 1, 408, 14]), 'atom37_atom_exists': torch.Size([1, 1, 408, 37]), 'target_feat': torch.Size([1, 1, 408, 22]), 'extra_msa': torch.Size([8, 1, 1024, 408]), 'extra_msa_mask': torch.Size([8, 1, 1024, 408]), 'extra_msa_row_mask': torch.Size([8, 1, 1024]), 'bert_mask': torch.Size([8, 1, 508, 408]), 'true_msa': torch.Size([8, 1, 508, 408]), 'extra_msa_has_deletion': torch.Size([8, 1, 1024, 408]), 'extra_msa_deletion_value': torch.Size([8, 1, 1024, 408]), 'msa_feat': torch.Size([8, 1, 508, 408, 49])}
Inference time: 138.50844680101727
Starting prediction...
/usr/local/lib/python3.7/site-packages/torch/jit/annotations.py:289: UserWarning: TorchScript will treat type annotations of Tensor dtype-specific subtypes as if they are normal Tensors. dtype constraints are not enforced in compilation either.
warnings.warn("TorchScript will treat type annotations of Tensor "
start to load params /root/Uni-Fold-main/Alphafold/multimer.unifold.pt
start to predict H1036
{'aatype': torch.Size([1, 1, 856]), 'residue_index': torch.Size([1, 1, 856]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([8, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 856]), 'template_all_atom_mask': torch.Size([1, 1, 4, 856, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 856, 37, 3]), 'asym_id': torch.Size([1, 1, 856]), 'sym_id': torch.Size([1, 1, 856]), 'entity_id': torch.Size([1, 1, 856]), 'num_sym': torch.Size([1, 1, 856]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([8, 1, 252, 856]), 'msa_mask': torch.Size([8, 1, 252, 856]), 'asym_len': torch.Size([1, 1, 3]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([8, 1]), 'seq_mask': torch.Size([1, 1, 856]), 'msa_row_mask': torch.Size([8, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 856, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 856]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 856, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 856, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 856, 37]), 'atom14_atom_exists': torch.Size([1, 1, 856, 14]), 'atom37_atom_exists': torch.Size([1, 1, 856, 37]), 'target_feat': torch.Size([1, 1, 856, 22]), 'extra_msa': torch.Size([8, 1, 1152, 856]), 'extra_msa_mask': torch.Size([8, 1, 1152, 856]), 'extra_msa_row_mask': torch.Size([8, 1, 1152]), 'true_msa': torch.Size([8, 1, 252, 856]), 'msa_feat': torch.Size([8, 1, 252, 856, 49]), 'extra_msa_has_deletion': torch.Size([8, 1, 1152, 856]), 'extra_msa_deletion_value': torch.Size([8, 1, 1152, 856])}
Inference time: 410.5106331880088
{'aatype': torch.Size([1, 1, 856]), 'residue_index': torch.Size([1, 1, 856]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([8, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 856]), 'template_all_atom_mask': torch.Size([1, 1, 4, 856, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 856, 37, 3]), 'asym_id': torch.Size([1, 1, 856]), 'sym_id': torch.Size([1, 1, 856]), 'entity_id': torch.Size([1, 1, 856]), 'num_sym': torch.Size([1, 1, 856]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([8, 1, 252, 856]), 'msa_mask': torch.Size([8, 1, 252, 856]), 'asym_len': torch.Size([1, 1, 3]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([8, 1]), 'seq_mask': torch.Size([1, 1, 856]), 'msa_row_mask': torch.Size([8, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 856, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 856]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 856, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 856, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 856, 37]), 'atom14_atom_exists': torch.Size([1, 1, 856, 14]), 'atom37_atom_exists': torch.Size([1, 1, 856, 37]), 'target_feat': torch.Size([1, 1, 856, 22]), 'extra_msa': torch.Size([8, 1, 1152, 856]), 'extra_msa_mask': torch.Size([8, 1, 1152, 856]), 'extra_msa_row_mask': torch.Size([8, 1, 1152]), 'true_msa': torch.Size([8, 1, 252, 856]), 'msa_feat': torch.Size([8, 1, 252, 856, 49]), 'extra_msa_has_deletion': torch.Size([8, 1, 1152, 856]), 'extra_msa_deletion_value': torch.Size([8, 1, 1152, 856])}
Inference time: 406.87861637599417
{'aatype': torch.Size([1, 1, 856]), 'residue_index': torch.Size([1, 1, 856]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([8, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 856]), 'template_all_atom_mask': torch.Size([1, 1, 4, 856, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 856, 37, 3]), 'asym_id': torch.Size([1, 1, 856]), 'sym_id': torch.Size([1, 1, 856]), 'entity_id': torch.Size([1, 1, 856]), 'num_sym': torch.Size([1, 1, 856]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([8, 1, 252, 856]), 'msa_mask': torch.Size([8, 1, 252, 856]), 'asym_len': torch.Size([1, 1, 3]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([8, 1]), 'seq_mask': torch.Size([1, 1, 856]), 'msa_row_mask': torch.Size([8, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 856, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 856]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 856, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 856, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 856, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 856, 37]), 'atom14_atom_exists': torch.Size([1, 1, 856, 14]), 'atom37_atom_exists': torch.Size([1, 1, 856, 37]), 'target_feat': torch.Size([1, 1, 856, 22]), 'extra_msa': torch.Size([8, 1, 1152, 856]), 'extra_msa_mask': torch.Size([8, 1, 1152, 856]), 'extra_msa_row_mask': torch.Size([8, 1, 1152]), 'true_msa': torch.Size([8, 1, 252, 856]), 'msa_feat': torch.Size([8, 1, 252, 856, 49]), 'extra_msa_has_deletion': torch.Size([8, 1, 1152, 856]), 'extra_msa_deletion_value': torch.Size([8, 1, 1152, 856])}
Inference time: 406.9402783330006
plddts {'multimer_ft_multimer.unifold.pt_20281': '0.4212213', 'multimer_ft_multimer.unifold.pt_2806': '0.41139442', 'multimer_ft_multimer.unifold.pt_55231': '0.4146896'}
ptms {'multimer_ft_multimer.unifold.pt_20281': '0.99775934', 'multimer_ft_multimer.unifold.pt_2806': '0.99926674', 'multimer_ft_multimer.unifold.pt_55231': '0.99753684'}
#!/bin/bash
export DTKROOT=/opt/dtk-22.04.2
export AMDGPU_TARGETS="gfx906"
export ROCMVER=22.04.2
export DTK_HOME=/opt/dtk-22.04.2
export ROCM_PATH=${DTK_HOME}
export HIP_PATH=${DTK_HOME}/hip
export PATH=${DTK_HOME}/bin:${DTK_HOME}/llvm/bin:${DTK_HOME}/hip/bin:${DTK_HOME}/miopen/bin:$PATH
export LD_LIBRARY_PATH=${DTK_HOME}/lib:${DTK_HOME}/lib64:${DTK_HOME}/hip/lib:${DTK_HOME}/llvm/lib:${DTK_HOME}/miopen/lib:$LD_LIBRARY_PATH
export INCLUDE=${DTK_HOME}/include:${DTK_HOME}/hip/include:${DTK_HOME}/llvm/include:$INCLUDE
export C_INCLUDE_PATH=${DTK_HOME}/include:${DTK_HOME}/hip/include:${DTK_HOME}/llvm/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=${DTK_HOME}/include:$CPLUS_INCLUDE_PATH
export MIOPEN_SYSTEM_DB_PATH=${DTK_HOME}/miopen/share/miopen/db/
source /root/env.sh
cd /root/Uni-Fold-main
fasta_path=/root/Uni-Fold-main/data/T1024.fasta
input_dir_base=/root/Uni-Fold-main/data
output_dir_base=/root/Uni-Fold-main/data
database_dir=/alphafold/alphafold/
max_template_date=2020-05-01
model_name=model_2_ft
param_path=/root/Uni-Fold-main/Alphafold/monomer.unifold.pt
echo "Starting homogeneous searching..."
python unifold/homo_search.py \
--fasta_path=$fasta_path \
--max_template_date=$max_template_date \
--output_dir=$output_dir_base \
--uniref90_database_path=$database_dir/uniref90/uniref90.fasta \
--mgnify_database_path=$database_dir/mgnify/mgy_clusters_2018_12.fa \
--bfd_database_path=$database_dir/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniclust30_database_path=$database_dir/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniprot_database_path=$database_dir/uniprot/uniprot.fasta \
--pdb_seqres_database_path=$database_dir/pdb_seqres/pdb_seqres.txt \
--template_mmcif_dir=$database_dir/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=$database_dir/pdb_mmcif/obsolete.dat \
--use_precomputed_msas=True
echo "Starting prediction..."
fasta_file=$(basename $fasta_path)
target_name=${fasta_file%.fa*}
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 unifold/inference.py \
--model_name=$model_name \
--param_path=$param_path \
--data_dir=$input_dir_base \
--target_name=$target_name \
--output_dir=$output_dir_base
#!/bin/bash
export DTKROOT=/opt/dtk-22.04.2
export AMDGPU_TARGETS="gfx906"
export ROCMVER=22.04.2
export DTK_HOME=/opt/dtk-22.04.2
export ROCM_PATH=${DTK_HOME}
export HIP_PATH=${DTK_HOME}/hip
export PATH=${DTK_HOME}/bin:${DTK_HOME}/llvm/bin:${DTK_HOME}/hip/bin:${DTK_HOME}/miopen/bin:$PATH
export LD_LIBRARY_PATH=${DTK_HOME}/lib:${DTK_HOME}/lib64:${DTK_HOME}/hip/lib:${DTK_HOME}/llvm/lib:${DTK_HOME}/miopen/lib:$LD_LIBRARY_PATH
export INCLUDE=${DTK_HOME}/include:${DTK_HOME}/hip/include:${DTK_HOME}/llvm/include:$INCLUDE
export C_INCLUDE_PATH=${DTK_HOME}/include:${DTK_HOME}/hip/include:${DTK_HOME}/llvm/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=${DTK_HOME}/include:$CPLUS_INCLUDE_PATH
export MIOPEN_SYSTEM_DB_PATH=${DTK_HOME}/miopen/share/miopen/db/
source /root/env.sh
cd /root/Uni-Fold-main
fasta_path=/root/Uni-Fold-main/data/H1036.fasta
input_dir_base=/root/Uni-Fold-main/data
output_dir_base=/root/Uni-Fold-main/data
database_dir=/alphafold/alphafold/
max_template_date=2020-05-01
model_name=multimer_ft #model_2_ft
param_path=/root/Uni-Fold-main/Alphafold/multimer.unifold.pt
echo "Starting homogeneous searching..."
python unifold/homo_search.py \
--fasta_path=$fasta_path \
--max_template_date=$max_template_date \
--output_dir=$output_dir_base \
--uniref90_database_path=$database_dir/uniref90/uniref90.fasta \
--mgnify_database_path=$database_dir/mgnify/mgy_clusters_2018_12.fa \
--bfd_database_path=$database_dir/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniclust30_database_path=$database_dir/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniprot_database_path=$database_dir/uniprot/uniprot.fasta \
--pdb_seqres_database_path=$database_dir/pdb_seqres/pdb_seqres.txt \
--template_mmcif_dir=$database_dir/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=$database_dir/pdb_mmcif/obsolete.dat \
--use_precomputed_msas=True
echo "Starting prediction..."
fasta_file=$(basename $fasta_path)
target_name=${fasta_file%.fa*}
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 unifold/inference.py \
--model_name=$model_name \
--param_path=$param_path \
--data_dir=$input_dir_base \
--target_name=$target_name \
--output_dir=$output_dir_base
fasta_path=$1
symmetry=$2
output_dir_base=$3
database_dir=$4
max_template_date=$5
param_path=$6
echo "Starting homogeneous searching..."
python unifold/homo_search.py \
--fasta_path=$fasta_path \
--max_template_date=$max_template_date \
--output_dir=$output_dir_base \
--uniref90_database_path=$database_dir/uniref90/uniref90.fasta \
--mgnify_database_path=$database_dir/mgnify/mgy_clusters_2018_12.fa \
--bfd_database_path=$database_dir/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniclust30_database_path=$database_dir/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniprot_database_path=$database_dir/uniprot/uniprot.fasta \
--pdb_seqres_database_path=$database_dir/pdb_seqres/pdb_seqres.txt \
--template_mmcif_dir=$database_dir/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=$database_dir/pdb_mmcif/obsolete.dat \
--use_precomputed_msas=True
echo "Starting prediction..."
fasta_file=$(basename $fasta_path)
target_name=${fasta_file%.fa*}
python unifold/inference_symmetry.py \
--symmetry=$symmetry \
--param_path=$param_path \
--data_dir=$output_dir_base \
--target_name=$target_name \
--output_dir=$output_dir_base
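# A usage sketch for the UF-Symmetry pipeline script above. The script name and all
# paths are assumptions; the six positional arguments mirror the assignments at the top
# of the script (FASTA, symmetry group, output dir, database dir, max template date,
# parameter path):
bash run_uf_symmetry.sh ./data/T1024.fasta C2 ./output /alphafold/alphafold/ 2020-05-01 ./uf_symmetry.pt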
fasta_path=$1
output_dir_base=$2
database_dir=$3
max_template_date=$4
model_name=$5
param_path=$6
echo "Starting homogeneous searching..."
python unifold/homo_search.py \
--fasta_path=$fasta_path \
--max_template_date=$max_template_date \
--output_dir=$output_dir_base \
--uniref90_database_path=$database_dir/uniref90/uniref90.fasta \
--mgnify_database_path=$database_dir/mgnify/mgy_clusters_2018_12.fa \
--bfd_database_path=$database_dir/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniclust30_database_path=$database_dir/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniprot_database_path=$database_dir/uniprot/uniprot.fasta \
--pdb_seqres_database_path=$database_dir/pdb_seqres/pdb_seqres.txt \
--template_mmcif_dir=$database_dir/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=$database_dir/pdb_mmcif/obsolete.dat \
--use_precomputed_msas=True
echo "Starting prediction..."
fasta_file=$(basename $fasta_path)
target_name=${fasta_file%.fa*}
python unifold/inference.py \
--model_name=$model_name \
--param_path=$param_path \
--data_dir=$output_dir_base \
--target_name=$target_name \
--output_dir=$output_dir_base
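# A corresponding sketch for the general pipeline script above (script name and paths
# assumed; arguments: FASTA, output dir, database dir, max template date, model name,
# parameter path):
bash run_unifold.sh ./data/H1036.fasta ./output /alphafold/alphafold/ 2020-05-01 multimer_ft ./multimer.unifold.pt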
import torch
import sys
from unifold.config import model_config
from unifold.modules.alphafold import AlphaFold
from scripts.translate_jax_params import (
    import_jax_weights_,
)

# paths of the source (JAX) checkpoint and the target Uni-Fold checkpoint,
# plus the model name used to build the config
load_ckpt = sys.argv[1]
save_ckpt = sys.argv[2]
model_name = sys.argv[3]
config = model_config(model_name)
model = AlphaFold(config)
# load the AlphaFold JAX parameters into the PyTorch model in-place
import_jax_weights_(model, load_ckpt, version=model_name)
state_dict = model.state_dict()
# wrap the weights in a Uni-Core-style checkpoint: EMA params plus minimal trainer state
save_state_dict = {}
save_state_dict["ema"] = {}
save_state_dict["extra_state"] = {}
save_state_dict["extra_state"]["train_iterator"] = {}
save_state_dict["extra_state"]["train_iterator"]["epoch"] = 1
update_state_dict = {"model." + k: v for k, v in state_dict.items()}
save_state_dict["ema"]["params"] = update_state_dict
torch.save(save_state_dict, save_ckpt)
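# A hypothetical invocation of the converter above. The script filename and the
# parameter file name are assumptions; the third argument must name a model accepted
# by model_config and match the version of the JAX parameters:
python convert_jax_params.py params_model_2.npz monomer.unifold.pt model_2_ft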
import torch
import sys
def openfold2unifold(model_states):
    """Rename OpenFold parameter keys to the Uni-Fold naming scheme."""
    new_model_states = {}
    mul_projs = {}
    mul_gates = {}
    for key, value in model_states.items():
        new_key = key
        if "msa_att_col._msa_att" in key:
            new_key = new_key.replace("msa_att_col._msa_att", "msa_att_col")
        if "extra_msa_stack.stack" in key:
            new_key = new_key.replace("extra_msa_stack.stack", "extra_msa_stack")
        if "tri_mul" in key:
            # Uni-Fold fuses the a/b projections (and gates) of triangular
            # multiplication into single linear layers; record the fused key
            # here and concatenate the two weight tensors below.
            if "linear_a_p" in key or "linear_b_p" in key:
                new_key = key.replace("linear_a_p", "linear_ab_p").replace(
                    "linear_b_p", "linear_ab_p"
                )
                mul_projs[new_key] = 1
                continue
            if "linear_a_g" in key or "linear_b_g" in key:
                new_key = key.replace("linear_a_g", "linear_ab_g").replace(
                    "linear_b_g", "linear_ab_g"
                )
                mul_gates[new_key] = 1
                continue
        if ".tm." in key:
            # the TM-score head is named `pae` in Uni-Fold
            new_key = new_key.replace(".tm.", ".pae.")
        if ".core." in key:
            new_key = new_key.replace("core.", "")
        new_model_states[new_key] = value
    for key in mul_projs:
        # concatenate the a/b projection weights into the fused layer
        new_key = key
        k1 = key.replace("linear_ab_p", "linear_a_p")
        k2 = key.replace("linear_ab_p", "linear_b_p")
        weight = torch.cat([model_states[k1], model_states[k2]], dim=0)
        if ".core." in key:
            new_key = new_key.replace("core.", "")
        new_model_states[new_key] = weight
    for key in mul_gates:
        # concatenate the a/b gate weights into the fused layer
        new_key = key
        k1 = key.replace("linear_ab_g", "linear_a_g")
        k2 = key.replace("linear_ab_g", "linear_b_g")
        weight = torch.cat([model_states[k1], model_states[k2]], dim=0)
        if ".core." in key:
            new_key = new_key.replace("core.", "")
        new_model_states[new_key] = weight
    return new_model_states

load_ckpt = sys.argv[1]
save_ckpt = sys.argv[2]
state_dict = torch.load(load_ckpt)
state_dict = openfold2unifold(state_dict)
save_state_dict = {}
save_state_dict["ema"] = {}
save_state_dict["extra_state"] = {}
save_state_dict["extra_state"]["train_iterator"] = {}
save_state_dict["extra_state"]["train_iterator"]["epoch"] = 1
update_state_dict = {"model." + k: v for k, v in state_dict.items()}
save_state_dict["ema"]["params"] = update_state_dict
torch.save(save_state_dict, save_ckpt)
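# Similarly, a hypothetical invocation of the OpenFold converter above. Script and
# checkpoint names are assumptions, and the input is assumed to be a raw OpenFold
# state dict loadable with torch.load:
python convert_openfold_params.py openfold_model.pt openfold.unifold.pt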