Commit 319d9d8b authored by yuhai

Initial commit
{
"name": "nbdev_template-codespaces",
"dockerComposeFile": "docker-compose.yml",
"service": "watcher",
"settings": {"terminal.integrated.shell.linux": "/bin/bash"},
"mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ],
"forwardPorts": [4000, 8080],
"appPort": [4000, 8080],
"extensions": ["ms-python.python",
"ms-azuretools.vscode-docker"],
"runServices": ["notebook", "jekyll", "watcher"],
"postStartCommand": "pip install -e ."
}
.jekyll-cache/
Gemfile.lock
*.bak
.gitattributes
.last_checked
.gitconfig
*.bak
*.log
*~
~*
_tmp*
tmp*
tags
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.vscode
*.swp
# osx generated files
.DS_Store
.DS_Store?
.Trashes
ehthumbs.db
Thumbs.db
.idea
# pytest
.pytest_cache
# tools/trust-doc-nbs
docs_src/.last_checked
# symlinks to fastai
docs_src/fastai
tools/fastai
# link checker
checklink/cookies.txt
# .gitconfig is now autogenerated
.gitconfig
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# default_exp core"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Iterative_masking\n",
"\n",
"> Use MSA Transformer to generate synthetic protein sequences by masking iteratively the same MSA."
]
},
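{
"cell_type": "markdown",
"metadata": {},
"source": [
"Schematically, one run of the procedure repeatedly masks a fraction `p_mask` of the MSA tokens, lets the model re-predict them, and feeds the result back in as the next input. The toy cell below sketches only this loop, with a random stand-in for MSA Transformer, so it runs without the model or a GPU; it is an illustration, not part of the exported library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of the iterative-masking loop (random stand-in model, illustration only)\n",
"import numpy as np\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def toy_model_logits(msa):  # stand-in for MSA Transformer: random logits over 21 symbols\n",
"    return rng.random(msa.shape + (21,))\n",
"\n",
"msa = rng.integers(0, 21, size=(8, 50))  # toy MSA: 8 sequences x 50 columns\n",
"p_mask = 0.1\n",
"for it in range(5):\n",
"    mask = rng.random(msa.shape) < p_mask   # True where a token is masked\n",
"    logits = toy_model_logits(msa)          # the real code runs MSA Transformer on the masked MSA\n",
"    resampled = logits.argmax(axis=-1)      # greedy choice (the use_pdf=False analogue)\n",
"    msa = np.where(mask, resampled, msa)    # only the masked entries are replaced\n",
"msa.shape"
]
},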
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"\n",
"import numpy as np\n",
"import esm\n",
"from numba import njit, prange\n",
"import torch\n",
"from Bio import SeqIO\n",
"import itertools\n",
"from typing import List, Tuple\n",
"import string\n",
"from warnings import warn\n",
"\n",
"torch.set_grad_enabled(False)\n",
"\n",
"\n",
"# Iterative masking MSA-Transformer\n",
"class IM_MSA_Transformer:\n",
" \"\"\"Class that implement the Iterative masking algorithm\"\"\"\n",
" def __init__(self,\n",
" iterations=None,\n",
" p_mask=None,\n",
" filename=None,\n",
" num=None,\n",
" filepath=None):\n",
"\n",
" self.iterations = iterations # number of iterations used to generate the MSA\n",
" self.p_mask = p_mask # masking probability for the MSA generation\n",
" #---------------------------------------------------------------------------------------\n",
" # Delete lowercase characters and punctuations from a string (input fasta file)\n",
" self.deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
" self.deletekeys[\".\"] = None\n",
" self.deletekeys[\"*\"] = None\n",
" self.translation = str.maketrans(self.deletekeys)\n",
" #---------------------------------------------------------------------------------------\n",
" if filename is None or num is None or filepath is None:\n",
" raise ValueError(\"`filepath`, `filename` and `num` must be specified to import the MSA\")\n",
" # Import Transformer model\n",
" self.msa_transformer, self.msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()\n",
" self.msa_transformer = self.msa_transformer.eval().cuda()\n",
" self.msa_batch_converter = self.msa_alphabet.get_batch_converter()\n",
" self.idx_list = self.msa_alphabet.tok_to_idx\n",
" print('MSA Transformer model imported')\n",
"\n",
" # If filename is an array then it's the input MSA\n",
" with torch.no_grad():\n",
" if isinstance(filename,np.ndarray):\n",
" self.msa_data = torch.Tensor(filename).type(torch.int64)\n",
" if len(filename.shape) != 3:\n",
" raise ValueError(\"`filename` should be an array with 3 axes\")\n",
" self.msa_batch_tokens = self.msa_data[:, :num[0], :]\n",
" print('Using MSA given in input')\n",
" else:\n",
" if len(num) != len(filename):\n",
" raise ValueError(\"`filename` and `num` must have the same length\")\n",
" #---------------------------------------------------------------------------------------\n",
" # Import MSAs\n",
" self.msa_data = []\n",
" for ff, nn in zip(filename, num):\n",
" self.msa_data += [self.read_msa(filepath + '/' + ff, nn)]\n",
" print('MSA Imported')\n",
" #---------------------------------------------------------------------------------------\n",
" # Create tokens starting from MSA\n",
" self.msa_batch_labels, self.msa_batch_strs, self.msa_batch_tokens = self.msa_batch_converter(\n",
" self.msa_data)\n",
" self.msa_data = (self.msa_batch_tokens).clone()\n",
" print(f'We are using batch MSAs of {num[0]} sequences')\n",
" self.msa_batch_tokens = self.msa_batch_tokens[:, :num[0], :]\n",
"\n",
" # Import tokens into cuda\n",
" self.msa_batch_tokens = self.msa_batch_tokens.cuda()\n",
"\n",
" print('MSA converted into tokens tensor of size and type:')\n",
" print(self.msa_batch_tokens.size(), self.msa_batch_tokens.dtype)\n",
"\n",
" #---------------------------------------------------------------------------------------\n",
" # Useful functions for handling string sequences\n",
" def read_sequence(self, filename: str) -> Tuple[str, str]:\n",
" \"\"\" Reads the first (reference) sequences from a fasta or MSA file.\"\"\"\n",
" record = next(SeqIO.parse(filename, \"fasta\"))\n",
" return record.description, str(record.seq)\n",
"\n",
" def remove_insertions(self, sequence: str) -> str:\n",
" \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n",
" return sequence.translate(self.translation)\n",
"\n",
" def read_msa(self, filename: str, nseq: int) -> List[Tuple[str, str]]:\n",
" \"\"\" Reads the first nseq sequences from an MSA file, automatically removes insertions.\"\"\"\n",
" tot = len([elem.id for elem in SeqIO.parse(filename, \"fasta\")])\n",
" print(f'Number of sequences in {filename}: ', tot)\n",
" return [\n",
" (record.description, self.remove_insertions(str(record.seq)))\n",
" for record in itertools.islice(SeqIO.parse(filename, \"fasta\"), tot)]\n",
"\n",
"#-----------------------------------------------------------------------------------------------------------------------\n",
"# USEFUL FUNCTIONS TO RUN THE MSA TRANSFORMER ON INFERENCE MODE\n",
"#-----------------------------------------------------------------------------------------------------------------------\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def print_tokens(self, tokens=None):\n",
" \"\"\"\n",
" Outputs (on the cpu) the input `tokens` of the MSA, detaching them from the GPU.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" if tokens is None:\n",
" return ((self.msa_batch_tokens.detach().cpu()).to(\n",
" dtype=torch.int8)).numpy()\n",
" else:\n",
" return ((tokens.detach().cpu()).to(dtype=torch.int8)).numpy()\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def compute_embeddings(self, tokens=None, lyrs=[12]):\n",
" \"\"\"\n",
" Starting from the `tokens`, use the model to predict their output embeddings and their associated\n",
" logits (when softmaxed they give the probability of each token)\n",
" `lyrs`: list of the layers from which extracting the embeddings (# 12 is the last layer)\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" if tokens is None:\n",
" tokens = self.msa_batch_tokens\n",
" if not tokens.is_cuda:\n",
" tokens = tokens.cuda()\n",
" results = self.msa_transformer(tokens,\n",
" repr_layers=lyrs,\n",
" return_contacts=False)\n",
" token_representations = results[\"representations\"][lyrs[0]].detach().cpu().numpy()\n",
" logits = results[\"logits\"].detach().cpu().numpy()\n",
" del results\n",
" return token_representations, logits\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def compute_contacts(self, tokens=None):\n",
" \"\"\"\n",
" Starting from the `tokens`, use the model to predict the contact matrix of each MSA\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" if tokens is None:\n",
" tokens = self.msa_batch_tokens\n",
" if not tokens.is_cuda:\n",
" tokens = tokens.cuda()\n",
" msa_contacts = self.msa_transformer.predict_contacts(tokens).cpu()\n",
" return msa_contacts\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" @njit(parallel=True)\n",
" def Weights_Phylogeny(tkn, delta=0.8):\n",
" \"\"\"\n",
" Compute the Phylogeny weights of the sequences\n",
" `tkn`: the 2d array of tokens of one MSA, it should not have the first token (0)\n",
" and it should end before the start of the padding tokens (1).\n",
" `delta`: the phylogeny parameter\n",
" \"\"\"\n",
" depth, length = tkn.shape\n",
"\n",
" def _inner(seq1, seq2):\n",
" return np.sum(seq1 != seq2) / length\n",
"\n",
" weights = np.empty(depth, dtype=np.float64)\n",
" for i in prange(depth):\n",
" dists = np.empty(depth, dtype=np.float64)\n",
" for j in range(depth):\n",
" dists[j] = _inner(tkn[i], tkn[j])\n",
" within_neighbourhood = np.sum(dists < 1 - delta)\n",
" weights[i] = 1 / within_neighbourhood\n",
" return weights\n",
"\n",
"\n",
"#-----------------------------------------------------------------------------------------------------------------------\n",
"# USEFUL FUNCTIONS FOR THE MSA GENERATION WITH THE TRANSFORMER ON INFERENCE MODE\n",
"#-----------------------------------------------------------------------------------------------------------------------\n",
"\n",
"#-------------------------------------------------------------------------------------------------------------------\n",
"# Softmax of the logits tensor\n",
"\n",
" def softmax_tensor(self, x, axis, T=1):\n",
" \"\"\"\n",
" Compute softmax values for each sets of scores in `x` where `x` is the 4-d tensor of logits\n",
" and `T` is the sampling temperature.\n",
" \"\"\"\n",
" return torch.exp(x/T) / torch.sum(torch.exp(x/T), axis=axis)[:, :, :, None]\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def generate_MSA(self, MSA_tokens, mask_idx=32, use_pdf=False, sample_all=False, T=1):\n",
" \"\"\"\n",
" Generate a new MSA by masking some entries of the original MSA and\n",
" re-predicting them through MSA Transformer.\n",
"\n",
" `MSA_tokens`: input tokens.\n",
"\n",
" `p_mask`: probability that an entry of the MSA is masked.\n",
"\n",
" `mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.\n",
"\n",
" `use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
" `sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
" `T`: Temperature of sampling from the pdf of output logits.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" if not MSA_tokens.is_cuda:\n",
" MSA_tokens = MSA_tokens.cuda()\n",
" mask = ((torch.rand(MSA_tokens.shape) > self.p_mask).type(\n",
" torch.uint8)).cuda()\n",
" masked_msa_tokens = MSA_tokens * mask + mask_idx * (1 - mask)\n",
" results = self.msa_transformer(masked_msa_tokens,\n",
" repr_layers=[12],\n",
" return_contacts=False)\n",
" msa_logits = self.softmax_tensor(x=results[\"logits\"], axis=3, T=T)\n",
" if use_pdf == False:\n",
" new_msa_tokens = torch.argmax(msa_logits, dim=3)\n",
" else:\n",
" Vals = torch.tensor([\n",
" 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,\n",
" 20, 21, 22, 23, 30\n",
" ],\n",
" dtype=torch.int64)\n",
" maxval = Vals[-1].cuda()\n",
" msa_logits = msa_logits[:, :, :, Vals]\n",
" msa_logits = msa_logits / (torch.sum(msa_logits,\n",
" axis=3)[:, :, :, None])\n",
" cum = torch.cumsum(msa_logits, dim=3)\n",
" idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()\n",
" idxs1 = Vals[None, None, None, :].cuda()\n",
" idxs = idxs + idxs1\n",
" sample = (torch.rand(\n",
" (cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()\n",
" idxs[torch.gt(sample[:, :, :, None], cum)] = 100\n",
" new_msa_tokens = torch.minimum(torch.amin(idxs, axis=3),\n",
" maxval)\n",
" del cum, idxs, idxs1, sample\n",
" if sample_all == False:\n",
" new_msa_tokens = MSA_tokens * mask + new_msa_tokens * (1 - mask)\n",
" new_msa_tokens[:, :, 0] = 0\n",
" del mask, masked_msa_tokens, results, msa_logits\n",
" return new_msa_tokens\n",
"\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def NEW_MSA(self, use_pdf=False, simplified=False, sample_all=False, T=1):\n",
" \"\"\"\n",
" Generate a new MSA by iteratively calling the masked MSA generator defined in: `self.generate_MSA`.\n",
"\n",
" ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True.\n",
"\n",
" The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
" the tokens should be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"\n",
" `use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
" `sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
" `T`: Temperature of sampling from the pdf of output logits.\n",
" \"\"\"\n",
" if self.iterations is None or self.p_mask is None:\n",
" raise ValueError(\n",
" \"Both `iterations` (numpy array) and `p_mask` (float) must be specified to generate a new MSA\"\n",
" )\n",
" max_iter = self.iterations[-1]\n",
" with torch.no_grad():\n",
" new_msa_tokens = self.msa_batch_tokens.clone()\n",
" all_tokens = torch.zeros(\n",
" (len(self.iterations), self.msa_batch_tokens.shape[0],\n",
" self.msa_batch_tokens.shape[1],\n",
" self.msa_batch_tokens.shape[2]),\n",
" dtype=torch.int64)\n",
" if simplified:\n",
" all_tokens = all_tokens.to(dtype=torch.int8)\n",
" if self.msa_alphabet.mask_idx != 32:\n",
" raise ValueError(\n",
" f\"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32\"\n",
" )\n",
" # Iterate the MSA generation process\n",
" j = 0\n",
" for i in range(max_iter):\n",
" new_msa_tokens = self.generate_MSA(\n",
" MSA_tokens=new_msa_tokens,\n",
" mask_idx=self.msa_alphabet.mask_idx,\n",
" use_pdf=use_pdf, sample_all=sample_all, T=T)\n",
" if np.any((i + 1) == self.iterations):\n",
" # Save the tokens at the specified iterations\n",
" if simplified:\n",
" all_tokens[j,\n",
" ...] = (new_msa_tokens.clone().detach().cpu()).to(\n",
" dtype=torch.int8)\n",
" else:\n",
" all_tokens[j, ...] = new_msa_tokens.clone()\n",
" j += 1\n",
" del new_msa_tokens\n",
" if simplified:\n",
" return all_tokens.numpy()\n",
" else:\n",
" return all_tokens.cuda()\n",
"\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" def Batch_MSA(self, use_pdf=False, simplified=False, repetitions=2, sample_all=False, T=1, phylo=False):\n",
" \"\"\"\n",
" Generate a full MSA by calling with different input MSAs the iterative MSA generator defined\n",
" in: `self.NEW_MSA`.\n",
"\n",
" ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
" The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
" the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"\n",
" `repetitions`: the number of times self.NEW_MSA() is repeated with a different input MSA.\n",
"\n",
" `use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
" `sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
" `T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
" `phylo`: if True the start sequences are sampled from phylogeny weights instead of randomly.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" all_tokens = np.zeros(\n",
" (len(self.iterations), self.msa_batch_tokens.shape[0],\n",
" self.msa_batch_tokens.shape[1] * repetitions,\n",
" self.msa_batch_tokens.shape[2]),\n",
" dtype=np.int64)\n",
" if simplified:\n",
" all_tokens = all_tokens.astype('int8')\n",
" ALL_tokens = self.msa_data\n",
" depth = self.msa_batch_tokens.shape[1]\n",
" if repetitions * depth > ALL_tokens.shape[1]:\n",
" all_tokens = np.zeros(\n",
" (len(self.iterations), self.msa_batch_tokens.shape[0],\n",
" ALL_tokens.shape[1], self.msa_batch_tokens.shape[2]),\n",
" dtype=np.int64)\n",
"\n",
" if not phylo:\n",
" ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n",
" else:\n",
" _ = self.Weights_Phylogeny(ALL_tokens[0, :20, :], delta=0.8)\n",
" phylo_w = self.Weights_Phylogeny(ALL_tokens[0, :, :], delta=0.8)\n",
" indxs = torch.multinomial(phylo_w, ALL_tokens.shape[1], replacement=True)\n",
" ALL_tokens = ALL_tokens[:, indxs, :]\n",
" for i in range(repetitions):\n",
" ind = torch.arange(i * depth, (i + 1) * depth)\n",
" if (i + 1) * depth > ALL_tokens.shape[1]:\n",
" ind = torch.arange(i * depth, ALL_tokens.shape[1])\n",
" self.msa_batch_tokens = ALL_tokens[:, ind, :]\n",
" self.msa_batch_tokens = self.msa_batch_tokens.cuda()\n",
" all_tokens[:, :,\n",
" ind.numpy(), :] = self.NEW_MSA(use_pdf=use_pdf, simplified=simplified, sample_all=sample_all, T=T)\n",
" if (i + 1) * depth > ALL_tokens.shape[1]:\n",
" break\n",
"\n",
" if simplified:\n",
" return (ALL_tokens[:, :repetitions *\n",
" depth, :].numpy()).astype('int8'), all_tokens\n",
" else:\n",
" return ALL_tokens[:, :repetitions *\n",
" depth, :], torch.from_numpy(all_tokens).cuda()\n",
"\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
"\n",
" def generate_MSA_context(self, ancestor, context, mask_idx=32, use_pdf=False, sample_all=False, T=1):\n",
" \"\"\"\n",
" Generate a new sequence by masking some entries of the original ancestor sequence and\n",
" re-predicting them through the transformer model (mask only `ancestor`, not the `context`).\n",
"\n",
" `ancestor`: input sequence to be masked iteratively.\n",
"\n",
" `context`: context MSA (not masked).\n",
"\n",
" `p_mask`: probability that an entry of the MSA is masked.\n",
"\n",
" `mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.\n",
"\n",
" `use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
" `sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
" `T`: Temperature of sampling from the pdf of output logits.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
"\n",
" if not ancestor.is_cuda:\n",
" ancestor = ancestor.cuda()\n",
" if not context.is_cuda:\n",
" context = context.cuda()\n",
"\n",
" mask = ((torch.rand(ancestor.shape) > self.p_mask).type(torch.uint8)).cuda()\n",
" masked_ancestor = ancestor * mask + mask_idx * (1 - mask)\n",
" \n",
" masked_msa_tokens = torch.zeros((context.shape[0],\n",
" context.shape[1]+1,\n",
" context.shape[2]),\n",
" dtype=torch.int64).cuda()\n",
" masked_msa_tokens[0, 0, :] = masked_ancestor\n",
" masked_msa_tokens[:, 1:, :] = context\n",
"\n",
" results = self.msa_transformer(masked_msa_tokens,\n",
" repr_layers=[12],\n",
" return_contacts=False)\n",
" results1 = results[\"logits\"][:,0,:,:]\n",
" results1 = results1[:,None,:,:]\n",
" msa_logits = self.softmax_tensor(x=results1, axis=3, T=T)\n",
"\n",
" if use_pdf == False:\n",
" new_generation = torch.argmax(msa_logits, dim=3)[0,0,:]\n",
" else:\n",
" Vals = torch.tensor([\n",
" 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,\n",
" 20, 21, 22, 23, 30],dtype=torch.int64)\n",
" maxval = Vals[-1].cuda()\n",
" msa_logits = msa_logits[:, :, :, Vals]\n",
" msa_logits = msa_logits / (torch.sum(msa_logits,\n",
" axis=3)[:, :, :, None])\n",
" cum = torch.cumsum(msa_logits, dim=3)\n",
" idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()\n",
" idxs1 = Vals[None, None, None, :].cuda()\n",
" idxs = idxs + idxs1\n",
" sample = (torch.rand(\n",
" (cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()\n",
" idxs[torch.gt(sample[:, :, :, None], cum)] = 100\n",
" new_generation = torch.minimum(torch.amin(idxs, axis=3),\n",
" maxval)[0,0,:]\n",
" del cum, idxs, idxs1, sample\n",
" \n",
" if sample_all == False:\n",
" new_generation = ancestor * mask + new_generation * (1 - mask)\n",
" new_generation[0] = 0\n",
"\n",
" del mask, masked_msa_tokens, results, results1, msa_logits\n",
" return new_generation\n",
"\n",
" #-------------------------------------------------------------------------------------------------------------------\n",
" # Generate new sequence in a Linear tree by reiterating the function `generate_MSA_context()` starting from the sequence:\n",
" # `ancestor` (original sequence) and using the sequences in `context` as context MSA.\n",
" def Context_MSA(self, depth=None, ancestor=None, context=None, use_pdf=False, simplified=False, sample_all=False, print_all=True, T=1):\n",
" \"\"\"\n",
" Generates a new MSA with context-generation by iterating the masking on the original ancestor sequence\n",
" using: `self.generate_MSA_context`. It masks `ancestor` (original sequence) and uses the sequences in `context` as context MSA.\n",
" \n",
" ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
" The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
" the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
" If `print_all`=True then it saves the generated sequences at each iteration.\n",
"\n",
" `ancestor`: input sequence to be masked iteratively.\n",
"\n",
" `context`: context MSA (not masked).\n",
"\n",
" `use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
" `sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
" `T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
" `depth`: number of generated sequences, if None the depth is the number of ancestor sequences.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" total_ran=False\n",
" if ancestor is None and context is None and depth is not None:\n",
" ALL_tokens = self.msa_data\n",
" ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n",
" ancestor = ALL_tokens[0,:depth,:]\n",
" ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n",
" context = ALL_tokens[:,:self.msa_batch_tokens.shape[1],:]\n",
" elif depth is None:\n",
" depth = ancestor.shape[0]\n",
" if isinstance(context,np.ndarray):\n",
" total_ran=False\n",
" elif context=='tot-ran':\n",
" total_ran=True\n",
" else:\n",
" print('ERROR, either you give depth or you give ancestor and context')\n",
"\n",
" all_tokens = torch.zeros((self.msa_batch_tokens.shape[0],\n",
" self.iterations[-1]+1,\n",
" depth,\n",
" ancestor.shape[1]),\n",
" dtype=torch.int64).cuda()\n",
"\n",
" ancestor = torch.from_numpy(ancestor).to(dtype=torch.int64)\n",
" if not total_ran:\n",
" context = torch.from_numpy(context).to(dtype=torch.int64)\n",
" if total_ran:\n",
" ALL_tokens = self.msa_data\n",
"\n",
" all_tokens[0, 0, :, :] = ancestor\n",
"\n",
" if simplified:\n",
" all_tokens = all_tokens.to(dtype=torch.int8)\n",
" if self.msa_alphabet.mask_idx != 32:\n",
" raise ValueError(\n",
" f\"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32\"\n",
" )\n",
" \n",
" # Iterate the MSA generation tree\n",
" for j in range(depth):\n",
" new_ancestor = all_tokens[0, 0, j, :]\n",
" for i in range(1,self.iterations[-1]+1):\n",
" if total_ran:\n",
" context = (ALL_tokens[:, torch.randperm(ALL_tokens.shape[1])[:self.msa_batch_tokens.shape[1]], :]).cuda()\n",
" new_ancestor = self.generate_MSA_context(ancestor=new_ancestor,context=context, mask_idx=self.msa_alphabet.mask_idx, use_pdf=use_pdf, sample_all=sample_all, T=T)\n",
" if print_all:\n",
" all_tokens[0, i, j, :] = new_ancestor\n",
" if not print_all:\n",
" all_tokens[0, -1, j, :] = new_ancestor\n",
" # torch.cuda.empty_cache()\n",
"\n",
" if not print_all:\n",
" all_tokens = all_tokens[:,torch.tensor([-1]),:,:]\n",
"\n",
" if simplified:\n",
" return ((context.detach().cpu()).to(dtype=torch.int8)).numpy(), ((all_tokens.detach().cpu()).to(dtype=torch.int8)).numpy()\n",
" else:\n",
" return context.cuda(), all_tokens.cuda()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h2 id=\"IM_MSA_Transformer\" class=\"doc_header\"><code>class</code> <code>IM_MSA_Transformer</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
"\n",
"> <code>IM_MSA_Transformer</code>(**`iterations`**=*`None`*, **`p_mask`**=*`None`*, **`filename`**=*`None`*, **`num`**=*`None`*, **`filepath`**=*`None`*)\n",
"\n",
"Class that implement the Iterative masking algorithm"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(IM_MSA_Transformer)"
]
},
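{
"cell_type": "markdown",
"metadata": {},
"source": [
"The constructor builds a translation table that `read_msa` (via `remove_insertions`) uses to drop insertion characters (lowercase letters, `.` and `*`) when loading aligned fasta files. The small cell below re-creates that table and applies it to a made-up string, purely as an illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration of the insertion-removal table used by `read_msa` (toy string, not a real sequence)\n",
"import string\n",
"deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
"deletekeys[\".\"] = None\n",
"deletekeys[\"*\"] = None\n",
"translation = str.maketrans(deletekeys)\n",
"'MKt-LV.aiR*'.translate(translation)  # lowercase insertions, '.' and '*' are removed -> 'MK-LVR'"
]
},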
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h4 id=\"IM_MSA_Transformer.Batch_MSA\" class=\"doc_header\"><code>IM_MSA_Transformer.Batch_MSA</code><a href=\"__main__.py#L303\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>IM_MSA_Transformer.Batch_MSA</code>(**`use_pdf`**=*`False`*, **`simplified`**=*`False`*, **`repetitions`**=*`2`*, **`sample_all`**=*`False`*, **`T`**=*`1`*, **`phylo`**=*`False`*)\n",
"\n",
"Generate a full MSA by calling with different input MSAs the iterative MSA generator defined\n",
"in: `self.NEW_MSA`.\n",
"\n",
"---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
"The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
"the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"\n",
"`repetitions`: the number of times self.NEW_MSA() is repeated with a different input MSA.\n",
"\n",
"`use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
"`sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
"`T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
"`phylo`: if True the start sequences are sampled from phylogeny weights instead of randomly."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(IM_MSA_Transformer.Batch_MSA)"
]
},
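{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `use_pdf=True`, `generate_MSA` samples each token by inverse-CDF sampling over the amino-acid part of the alphabet (token indices 4-23 and 30) after a temperature-`T` softmax of the logits. The cell below reproduces just that sampling step on a toy logits tensor on the CPU, so it runs without the model or a GPU; it is an illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CPU sketch of the use_pdf=True sampling scheme on toy logits (illustration only)\n",
"import torch\n",
"logits = torch.randn(1, 2, 3, 33)  # toy (n_msa, depth, length, alphabet) tensor\n",
"T = 1.0\n",
"p = torch.exp(logits / T) / torch.sum(torch.exp(logits / T), dim=3)[:, :, :, None]  # as in softmax_tensor\n",
"Vals = torch.tensor([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 30])\n",
"p = p[:, :, :, Vals]\n",
"p = p / torch.sum(p, dim=3)[:, :, :, None]  # renormalise over the kept tokens\n",
"cum = torch.cumsum(p, dim=3)\n",
"idxs = torch.zeros_like(cum, dtype=torch.int64) + Vals[None, None, None, :]\n",
"u = torch.rand(cum.shape[0], cum.shape[1], cum.shape[2])\n",
"idxs[torch.gt(u[:, :, :, None], cum)] = 100  # drop every token whose CDF is still below u\n",
"sampled = torch.minimum(torch.amin(idxs, dim=3), Vals[-1])  # first token with CDF >= u\n",
"sampled.shape, int(sampled.min()), int(sampled.max())"
]
},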
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h4 id=\"IM_MSA_Transformer.Context_MSA\" class=\"doc_header\"><code>IM_MSA_Transformer.Context_MSA</code><a href=\"__main__.py#L448\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>IM_MSA_Transformer.Context_MSA</code>(**`depth`**=*`None`*, **`ancestor`**=*`None`*, **`context`**=*`None`*, **`use_pdf`**=*`False`*, **`simplified`**=*`False`*, **`sample_all`**=*`False`*, **`print_all`**=*`True`*, **`T`**=*`1`*)\n",
"\n",
"Generates a new MSA with context-generation by iterating the masking on the original ancestor sequence\n",
"using: `self.generate_MSA_context`. It masks `ancestor` (original sequence) and uses the sequences in `context` as context MSA.\n",
"\n",
"---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
"The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
"the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"If `print_all`=True then it saves the generated sequences at each iteration.\n",
"\n",
"`ancestor`: input sequence to be masked iteratively.\n",
"\n",
"`context`: context MSA (not masked).\n",
"\n",
"`use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
"`sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
"`T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
"`depth`: number of generated sequences, if None the depth is the number of ancestor sequences."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(IM_MSA_Transformer.Context_MSA)"
]
},
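{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the context generation, `generate_MSA_context` places the (masked) ancestor sequence in the first row of the batch and the untouched context MSA in the remaining rows; only the first row is re-predicted. The toy cell below shows just that stacking step with random integer tokens, as an illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of how the ancestor is stacked on top of the context MSA (illustration only)\n",
"import torch\n",
"ancestor = torch.randint(4, 24, (7,))      # toy ancestor sequence of length 7\n",
"context = torch.randint(4, 24, (1, 3, 7))  # toy context MSA: 1 MSA x 3 sequences x length 7\n",
"stacked = torch.zeros((context.shape[0], context.shape[1] + 1, context.shape[2]), dtype=torch.int64)\n",
"stacked[0, 0, :] = ancestor                # row 0: the ancestor (the only row that is masked and re-predicted)\n",
"stacked[:, 1:, :] = context                # rows 1..: the context, left untouched\n",
"stacked.shape"
]
},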
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"import os\n",
"import pickle\n",
"from fastcore.script import *\n",
"\n",
"@call_parse\n",
"def gen_MSAs(filepath:Param(help='Path of the input directory',type=str,default='./'),\n",
" filename:Param(help='Name of the input file(s)',type=str,nargs='+',default=False),\n",
" new_dir:Param(help='Name of the output directory',type=str,default=False),\n",
" pdf:Param(help='Should I sample tokens from the pdf ? (bool)',type=bool_arg,default=False),\n",
" T:Param(help='Which is the sampling Temperature from the pdf ? (only when `pdf` is True)',type=float,default=1),\n",
" sample_all:Param(help='Should I sample all tokens or just the masked ones ? (True = sample all tokens)',type=bool_arg, default=False),\n",
" Iters:Param(help='Number of total iterations to generate the new tokens',type=int,default=10),\n",
" pmask:Param(help='Masking probability',type=float,default=0.1),\n",
" num:Param(help='Size of the batches MSAs which the MSA-Transformer receives as input',type=int,nargs='+',default=100),\n",
" depth:Param(help='Number of batches (of size num) that you want to generate',type=int,default=2),\n",
" generate:Param(help='How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.',type=str, default=False),\n",
" print_all:Param(help='Should I print the MSA after each iteration ? (bool)',type=bool_arg,default=False),\n",
" range_vals:Param(help='First and last index of the sequences that you want to use as ancestors', type=int,nargs='+',default=False),\n",
" phylo_w:Param(help='Should I sample the starting sequences from the phylogeny weights ? (bool)',type=bool_arg,default=False)\n",
" ):\n",
" \"Generate a new MSA either with Batch generation of Context generation. It shuffles the initial MSA and uses different slices as batch MSAs\"\n",
"\n",
" # Create folder\n",
" path = os.getcwd()\n",
" path1 = new_dir\n",
" if new_dir is False:\n",
" path1 = filename[0][:-6]\n",
" try:\n",
" os.mkdir(path + \"/\" + path1)\n",
" except OSError:\n",
" print(\"Creation of the directory %s failed\" % (path + \"/\" + path1))\n",
" else:\n",
" print(\"Successfully created the directory %s \" % (path + \"/\" + path1))\n",
"\n",
" # Save Input MSA\n",
" print('Tokenize')\n",
" Class = IM_MSA_Transformer(filename=filename,\n",
" num=[-1],\n",
" filepath=filepath)\n",
" idx_list = Class.idx_list\n",
" old_tkn = Class.print_tokens()\n",
" a_file = open(path1 + \"/dictionary-tokens.pkl\", \"wb\")\n",
" pickle.dump(idx_list, a_file)\n",
" a_file.close()\n",
" np.save(path1 + \"/original-tokens.npy\", old_tkn[0])\n",
"\n",
" add_strs = \"\"\n",
" if pdf==True:\n",
" add_strs += f\"_pdf(T={round(T,3)})\"\n",
" print(\n",
" \"We are sampling new tokens from the pdf of logits and not taking the mode of the pdf\"\n",
" )\n",
" if T!=1 and pdf==False:\n",
" print('To sample with a Temperature you should use pdf=True, otherwise the result is the same')\n",
" if sample_all == False:\n",
" add_strs += \"_(only-masked-sampled)\"\n",
" if not generate==False:\n",
" add_strs += \"_\"+generate+\"_(context-\"+str(num[0])+\")\"\n",
" if phylo_w:\n",
" add_strs += \"_phylo-w\"\n",
"\n",
" print('Generate Class')\n",
" Class = IM_MSA_Transformer(iterations=np.array([Iters]),\n",
" p_mask=pmask,\n",
" filename=filename,\n",
" num=num,\n",
" filepath=filepath)\n",
"\n",
" print('Compute results from Class')\n",
" Class.iterations = np.array([Iters])\n",
" Class.p_mask = pmask\n",
"\n",
" if generate == False:\n",
" print('Generating MSA with same size as the original one')\n",
" old_T, new_T = Class.Batch_MSA(simplified=True,\n",
" repetitions=depth,\n",
" use_pdf=pdf, sample_all=sample_all, T=T, phylo=phylo_w)\n",
" NNN = min(num[0] * depth, old_T.shape[1])\n",
"\n",
" elif generate=='linear-ran' or generate=='linear-tot-ran':\n",
" print('Generate MSA with linear context generation')\n",
" orig_tkn = np.load(path + \"/\" + path1 + \"/original-tokens.npy\")\n",
" # select ancestor and context\n",
" np.random.seed(0)\n",
" indices = np.random.permutation(orig_tkn.shape[0])\n",
" indexes_context = indices[:num[0]]\n",
" indices = np.random.permutation(orig_tkn.shape[0])\n",
" if depth == -1:\n",
" ind_ancestor = indices\n",
" elif range_vals is False:\n",
" ind_ancestor = indices[:depth]\n",
" else:\n",
" if range_vals[1] == -1 :\n",
" ind_ancestor = indices[range_vals[0]:]\n",
" range_vals[1] = orig_tkn.shape[0]\n",
" else:\n",
" ind_ancestor = indices[range_vals[0]:range_vals[1]]\n",
" ancestor = orig_tkn[ind_ancestor,:]\n",
" context = orig_tkn[indexes_context,:][None,:,:]\n",
" if generate=='linear-tot-ran':\n",
" context = 'tot-ran'\n",
" old_T, new_T = Class.Context_MSA(None, ancestor, context, use_pdf=pdf, simplified=True, sample_all=sample_all, print_all=print_all, T=T)\n",
" if generate=='linear-tot-ran':\n",
" old_T = ancestor[None,:,:]\n",
" NNN = new_T.shape[2]\n",
" else:\n",
" print('ERROR: Select a generative process')\n",
"\n",
" # define the name of the directory to be created and create it\n",
" path2 = \"Generated\" + \"_iter-\" + str(\n",
" Iters) + \"_pmask-\" + str(pmask) + \"_seqs-\" + str(NNN) + add_strs\n",
" try:\n",
" os.mkdir(path + \"/\" + path1 + \"/\" + path2)\n",
" except OSError:\n",
" print(\"Creation of the directory %s failed\" % (path + \"/\" +\n",
" path1 + \"/\" + path2))\n",
" else:\n",
" print(\"Successfully created the directory %s \" % (path + \"/\" +\n",
" path1 + \"/\" + path2))\n",
"\n",
" # Save data\n",
" if generate == False or generate=='linear-tot-ran':\n",
" np.save(path1 + \"/\" + path2 + \"/shuffled-tokens.npy\", old_T[0])\n",
" else:\n",
" np.save(path1 + \"/\" + path2 + \"/context-tokens.npy\", old_T[0])\n",
" str_add = ''\n",
" if range_vals is not False:\n",
" str_add = '_range_indx_'+str(range_vals[0])+','+str(range_vals[1])\n",
" np.save(path1 + \"/\" + path2 + \"/new-tokens\"+str_add+\".npy\", new_T[0])\n",
"\n",
" return 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h4 id=\"gen_MSAs\" class=\"doc_header\"><code>gen_MSAs</code><a href=\"__main__.py#L6\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>gen_MSAs</code>(**`filepath`**:\"Path of the input directory\", **`filename`**:\"Name of the input file(s)\", **`new_dir`**:\"Name of the output directory\", **`pdf`**:\"Should I sample tokens from the pdf ? (bool)\", **`T`**:\"Which is the sampling Temperature from the pdf ? (only when `pdf` is True)\", **`sample_all`**:\"Should I sample all tokens or just the masked ones ? (True = sample all tokens)\", **`Iters`**:\"Number of total iterations to generate the new tokens\", **`pmask`**:\"Masking probability\", **`num`**:\"Size of the batches MSAs which the MSA-Transformer receives as input\", **`depth`**:\"Number of batches (of size num) that you want to generate\", **`generate`**:\"How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.\", **`print_all`**:\"Should I print the MSA after each iteration ? (bool)\", **`range_vals`**:\"First and last index of the sequences that you want to use as ancestors\", **`phylo_w`**:\"Should I sample the starting sequences from the phylogeny weights ? (bool)\")\n",
"\n",
"Generate a new MSA either with Batch generation of Context generation. It shuffles the initial MSA and uses different slices as batch MSAs"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(gen_MSAs)"
]
},
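{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical command-line invocation of `gen_MSAs` (the console-script name, directory and file name below are placeholders; the actual script name is defined by this repository's `settings.ini`, which is not shown here):\n",
"\n",
"```\n",
"gen_MSAs --filepath ./data --filename my_family.fasta --Iters 20 --pmask 0.1 --num 600 --depth 2 --pdf True --T 1.0\n",
"```"
]
},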
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build library"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_core.ipynb.\n",
"Converted index.ipynb.\n",
"converting: /home/damiano/Documents/Projects/Iterative_masking/00_core.ipynb\n",
"converting: /home/damiano/Documents/Projects/Iterative_masking/index.ipynb\n",
"converting /home/damiano/Documents/Projects/Iterative_masking/index.ipynb to README.md\n"
]
}
],
"source": [
"# hide\n",
"\n",
"# RUN THIS CELL EVERYTIME YOU CHANGE THE NOTEBOOK SO THAT IT BUILDS THE NEW LIBRARY \n",
"# --------> no need to run nbdev_build_lib on the terminal\n",
"from nbdev.export import notebook2script\n",
"notebook2script()\n",
"# !nbdev_clean_nbs\n",
"!nbdev_build_docs"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# How to contribute
## How to get started
Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
```
nbdev_install_git_hooks
```
## Did you find a bug?
* Ensure the bug was not already reported by searching on GitHub under Issues.
* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
* Be sure to add the complete error messages.
#### Did you write a patch that fixes a bug?
* Open a new GitHub pull request with the patch.
* Ensure that your PR includes a test that fails without your patch, and passes with it.
* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
## PR submission guidelines
* Keep each PR focused. While it may be more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
* Do not mix style changes/fixes with "functional" changes. It is very difficult to review such PRs, and they will most likely be rejected.
* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
* Do not turn an already submitted PR into your development playground. If, after you have submitted a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it is probably best to close the PR, do the work, and then submit it again. Use common sense to choose one way over the other.
## Do you want to contribute to the documentation?
* Docs are automatically created from the notebooks in the nbs folder.
__version__ = "0.0.1"
# AUTOGENERATED BY NBDEV! DO NOT EDIT!
__all__ = ["index", "modules", "custom_doc_links", "git_url"]
index = {"IM_MSA_Transformer": "00_core.ipynb",
"gen_MSAs": "00_core.ipynb"}
modules = ["core.py"]
doc_url = "https://damiano-sg.github.io/Iterative_masking/"
git_url = "https://github.com/damiano-sg/Iterative_masking/tree/main/"
def custom_doc_links(name): return None
# AUTOGENERATED! DO NOT EDIT! File to edit: 00_core.ipynb (unless otherwise specified).
__all__ = ['IM_MSA_Transformer', 'gen_MSAs']
# Cell
import numpy as np
import esm
from numba import njit, prange
import torch
from Bio import SeqIO
import itertools
from typing import List, Tuple
import string
from warnings import warn
torch.set_grad_enabled(False)
# Iterative masking MSA-Transformer
class IM_MSA_Transformer:
"""Class that implement the Iterative masking algorithm"""
def __init__(self,
iterations=None,
p_mask=None,
filename=None,
num=None,
filepath=None):
self.iterations = iterations # number of iterations used to generate the MSA
self.p_mask = p_mask # masking probability for the MSA generation
self.num = num
#---------------------------------------------------------------------------------------
# Delete lowercase characters and punctuations from a string (input fasta file)
self.deletekeys = dict.fromkeys(string.ascii_lowercase)
self.deletekeys["."] = None
self.deletekeys["*"] = None
self.translation = str.maketrans(self.deletekeys)
#---------------------------------------------------------------------------------------
if filename is None or num is None or filepath is None:
raise ValueError(
"`filepath`, `filename` and `num` must be specified to import the MSA"
)
# Import Transformer model
self.msa_transformer, self.msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S(
)
self.msa_transformer = self.msa_transformer.eval().cuda()
self.msa_batch_converter = self.msa_alphabet.get_batch_converter()
self.idx_list = self.msa_alphabet.tok_to_idx
print('MSA Transformer model imported')
# If filename is an array then it's the input MSA
with torch.no_grad():
if isinstance(filename, np.ndarray):
self.msa_data = torch.Tensor(filename).type(torch.int64)
if len(filename.shape) != 3:
raise ValueError(
"`filename` should be an array with 3 axes")
self.msa_batch_tokens = self.msa_data[:, :num[0], :]
print('Using MSA given in input')
else:
if len(num) != len(filename):
raise ValueError(
"`filename` and `num` must have the same length")
#---------------------------------------------------------------------------------------
# Import MSAs
self.msa_data = []
for ff, nn in zip(filename, num):
self.msa_data += [self.read_msa(filepath + '/' + ff, nn)]
print('MSA Imported')
#---------------------------------------------------------------------------------------
# Create tokens starting from MSA
self.msa_batch_labels, self.msa_batch_strs, self.msa_batch_tokens = self.msa_batch_converter(
self.msa_data)
self.msa_data = (self.msa_batch_tokens).clone()
print(f'We are using batch MSAs of {num[0]} sequences')
self.msa_batch_tokens = self.msa_batch_tokens[:, :num[0], :]
# Import tokens into cuda
self.msa_batch_tokens = self.msa_batch_tokens.cuda()
print('MSA converted into tokens tensor of size and type:')
print(self.msa_batch_tokens.size(), self.msa_batch_tokens.dtype)
#---------------------------------------------------------------------------------------
# Useful functions for handling string sequences
def read_sequence(self, filename: str) -> Tuple[str, str]:
""" Reads the first (reference) sequences from a fasta or MSA file."""
record = next(SeqIO.parse(filename, "fasta"))
return record.description, str(record.seq)
def remove_insertions(self, sequence: str) -> str:
""" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
return sequence.translate(self.translation)
def read_msa(self, filename: str, nseq: int) -> List[Tuple[str, str]]:
""" Reads the first nseq sequences from an MSA file, automatically removes insertions."""
tot = len([elem.id for elem in SeqIO.parse(filename, "fasta")])
print(f'Number of sequences in {filename}: ', tot)
return [
(record.description, self.remove_insertions(str(record.seq)))
for record in itertools.islice(SeqIO.parse(filename, "fasta"), tot)
]
#-----------------------------------------------------------------------------------------------------------------------
# USEFUL FUNCTIONS TO RUN THE MSA TRANSFORMER ON INFERENCE MODE
#-----------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------
def print_tokens(self, tokens=None):
"""
Outputs (on the cpu) the input `tokens` of the MSA, detaching them from the GPU.
"""
with torch.no_grad():
if tokens is None:
return ((self.msa_batch_tokens.detach().cpu()).to(
dtype=torch.int8)).numpy()
else:
return ((tokens.detach().cpu()).to(dtype=torch.int8)).numpy()
#-------------------------------------------------------------------------------------------------------------------
def compute_embeddings(self, tokens=None, lyrs=[12]):
"""
Starting from the `tokens`, use the model to predict their output embeddings and the associated
logits (which, after a softmax, give the probability of each token).
`lyrs`: list of the layers from which to extract the embeddings (layer 12 is the last one)
"""
with torch.no_grad():
if tokens is None:
tokens = self.msa_batch_tokens
if not tokens.is_cuda:
tokens = tokens.cuda()
results = self.msa_transformer(tokens,
repr_layers=lyrs,
return_contacts=False)
token_representations = results["representations"][
lyrs[0]].detach().cpu().numpy()
logits = results["logits"].detach().cpu().numpy()
del results
return token_representations, logits
#-------------------------------------------------------------------------------------------------------------------
def compute_contacts(self, tokens=None):
"""
Starting from the `tokens`, use the model to predict the contact matrix of each MSA
"""
with torch.no_grad():
if tokens is None:
tokens = self.msa_batch_tokens
if not tokens.is_cuda:
tokens = tokens.cuda()
msa_contacts = self.msa_transformer.predict_contacts(tokens).cpu()
return msa_contacts
#-------------------------------------------------------------------------------------------------------------------
@njit(parallel=True)
def Weights_Phylogeny(tkn, delta=0.8):
"""
Compute the Phylogeny weights of the sequences
`tkn`: the 2d array of tokens of one MSA, it should not have the first token (0)
and it should end before the start of the padding tokens (1).
`delta`: the phylogeny parameter
"""
depth, length = tkn.shape
def _inner(seq1, seq2):
return np.sum(seq1 != seq2) / length
weights = np.empty(depth, dtype=np.float64)
for i in prange(depth):
dists = np.empty(depth, dtype=np.float64)
for j in range(depth):
dists[j] = _inner(tkn[i], tkn[j])
within_neighbourhood = np.sum(dists < 1 - delta)
weights[i] = 1 / within_neighbourhood
return weights
#-----------------------------------------------------------------------------------------------------------------------
# USEFUL FUNCTIONS FOR THE MSA GENERATION WITH THE TRANSFORMER ON INFERENCE MODE
#-----------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------
# Softmax of the logits tensor
def softmax_tensor(self, x, axis, T=1):
"""
Compute softmax values for each set of scores in `x`, where `x` is the 4-d tensor of logits
and `T` is the sampling temperature.
"""
return torch.exp(x / T) / torch.sum(torch.exp(x / T),
axis=axis)[:, :, :, None]
#-------------------------------------------------------------------------------------------------------------------
def generate_MSA(self,
MSA_tokens,
mask_idx=32,
use_pdf=False,
sample_all=False,
T=1):
"""
Generate a new MSA by masking some entries of the original MSA and
re-predicting them through MSA Transformer.
`MSA_tokens`: input tokens.
`p_mask`: probability that an entry of the MSA is masked.
`mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.
`use_pdf`: if True, the function samples the token from the logits pdf
instead of getting the argmax (greedy sampling).
`sample_all`: if True all the new tokens are obtained from the logits (both
the masked and the non masked), if False the non masked tokens
are left untouched and only the masked ones are changed.
`T`: Temperature of sampling from the pdf of output logits.
"""
with torch.no_grad():
if not MSA_tokens.is_cuda:
MSA_tokens = MSA_tokens.cuda()
mask = ((torch.rand(MSA_tokens.shape) > self.p_mask).type(
torch.uint8)).cuda()
masked_msa_tokens = MSA_tokens * mask + mask_idx * (1 - mask)
results = self.msa_transformer(masked_msa_tokens,
repr_layers=[12],
return_contacts=False)
msa_logits = self.softmax_tensor(x=results["logits"], axis=3, T=T)
if use_pdf == False:
new_msa_tokens = torch.argmax(msa_logits, dim=3)
else:
Vals = torch.tensor([
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 30
],
dtype=torch.int64)
maxval = Vals[-1].cuda()
msa_logits = msa_logits[:, :, :, Vals]
msa_logits = msa_logits / (torch.sum(msa_logits,
axis=3)[:, :, :, None])
cum = torch.cumsum(msa_logits, dim=3)
idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()
idxs1 = Vals[None, None, None, :].cuda()
idxs = idxs + idxs1
sample = (torch.rand(
(cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()
idxs[torch.gt(sample[:, :, :, None], cum)] = 100
new_msa_tokens = torch.minimum(torch.amin(idxs, axis=3),
maxval)
del cum, idxs, idxs1, sample
if sample_all == False:
new_msa_tokens = MSA_tokens * mask + new_msa_tokens * (1 -
mask)
new_msa_tokens[:, :, 0] = 0
del mask, masked_msa_tokens, results, msa_logits
return new_msa_tokens
#-------------------------------------------------------------------------------------------------------------------
def NEW_MSA(self, use_pdf=False, simplified=False, sample_all=False, T=1):
"""
Generate a new MSA by iteratively calling the masked MSA generator defined in: `self.generate_MSA`.
---> Use this function with `simplified`=False only if you need the tokens on cuda (i.e. if you want to compute embeddings
or contacts), otherwise use `simplified`=True.
The variable `self.iterations` must be a numpy array which specifies when (at which iterations)
the tokens should be saved. The last element of the array gives the maximum number of iterations that should be done.
`use_pdf`: if True, the function samples the token from the logits pdf
instead of getting the argmax (greedy sampling).
`sample_all`: if True all the new tokens are obtained from the logits (both
the masked and the non masked), if False the non masked tokens
are left untouched and only the masked ones are changed.
`T`: Temperature of sampling from the pdf of output logits.
"""
if self.iterations is None or self.p_mask is None:
raise ValueError(
"Both `iterations` (numpy array) and `p_mask` (float) must be specified to generate a new MSA"
)
max_iter = self.iterations[-1]
with torch.no_grad():
new_msa_tokens = self.msa_batch_tokens.clone()
all_tokens = torch.zeros(
(len(self.iterations), self.msa_batch_tokens.shape[0],
self.msa_batch_tokens.shape[1],
self.msa_batch_tokens.shape[2]),
dtype=torch.int64)
if simplified:
all_tokens = all_tokens.to(dtype=torch.int8)
if self.msa_alphabet.mask_idx != 32:
raise ValueError(
f"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32"
)
# Iterate the MSA generation process
j = 0
for i in range(max_iter):
new_msa_tokens = self.generate_MSA(
MSA_tokens=new_msa_tokens,
mask_idx=self.msa_alphabet.mask_idx,
use_pdf=use_pdf,
sample_all=sample_all,
T=T)
if np.any((i + 1) == self.iterations):
# Save the tokens at the specified iterations
if simplified:
all_tokens[j, ...] = (
new_msa_tokens.clone().detach().cpu()).to(
dtype=torch.int8)
else:
all_tokens[j, ...] = new_msa_tokens.clone()
j += 1
del new_msa_tokens
if simplified:
return all_tokens.numpy()
else:
return all_tokens.cuda()
#-------------------------------------------------------------------------------------------------------------------
def Batch_MSA(self,
use_pdf=False,
simplified=False,
repetitions=2,
sample_all=False,
T=1,
phylo=False):
"""
Generate a full MSA by calling the iterative MSA generator defined in `self.NEW_MSA`
with different input MSAs.
---> Use this function with `simplified`=False only if you need the tokens on cuda (i.e. if you want to compute embeddings
or contacts), otherwise use `simplified`=True
The variable `self.iterations` must be a numpy array which specifies when (at which iterations)
the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.
`repetitions`: the number of times self.NEW_MSA() is repeated with a different input MSA.
`use_pdf`: if True, the function samples the token from the logits pdf
instead of getting the argmax (greedy sampling).
`sample_all`: if True all the new tokens are obtained from the logits (both
the masked and the non masked), if False the non masked tokens
are left untouched and only the masked ones are changed.
`T`: Temperature of sampling from the pdf of output logits.
`phylo`: if True the start sequences are sampled from phylogeny weights instead of randomly.
"""
with torch.no_grad():
all_tokens = np.zeros(
(len(self.iterations), self.msa_batch_tokens.shape[0],
self.msa_batch_tokens.shape[1] * repetitions,
self.msa_batch_tokens.shape[2]),
dtype=np.int64)
if simplified:
all_tokens = all_tokens.astype('int8')
ALL_tokens = self.msa_data
depth = self.num[0]
if repetitions * depth > ALL_tokens.shape[1]:
all_tokens = np.zeros(
(len(self.iterations), self.msa_batch_tokens.shape[0],
ALL_tokens.shape[1], self.msa_batch_tokens.shape[2]),
dtype=np.int64)
if not phylo:
ALL_tokens = ALL_tokens[:,
torch.randperm(ALL_tokens.shape[1]), :]
else:
_ = self.Weights_Phylogeny(ALL_tokens[0, :20, :], delta=0.8)
phylo_w = self.Weights_Phylogeny(ALL_tokens[0, :, :],
delta=0.8)
indxs = torch.multinomial(phylo_w,
ALL_tokens.shape[1],
replacement=True)
ALL_tokens = ALL_tokens[:, indxs, :]
for i in range(repetitions):
ind = torch.arange(i * depth, (i + 1) * depth)
if (i + 1) * depth > ALL_tokens.shape[1]:
ind = torch.arange(i * depth, ALL_tokens.shape[1])
self.msa_batch_tokens = ALL_tokens[:, ind, :]
self.msa_batch_tokens = self.msa_batch_tokens.cuda()
                new_block = self.NEW_MSA(use_pdf=use_pdf,
                                         simplified=simplified,
                                         sample_all=sample_all,
                                         T=T)
                # `NEW_MSA` returns a numpy array when `simplified` is True and a cuda tensor
                # otherwise; move it to CPU/numpy before writing into the numpy buffer
                if torch.is_tensor(new_block):
                    new_block = new_block.detach().cpu().numpy()
                all_tokens[:, :, ind.numpy(), :] = new_block
if (i + 1) * depth > ALL_tokens.shape[1]:
break
if simplified:
return (ALL_tokens[:, :repetitions *
depth, :].numpy()).astype('int8'), all_tokens
else:
return ALL_tokens[:, :repetitions *
depth, :], torch.from_numpy(all_tokens).cuda()
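    # Hedged usage sketch (illustration only; `im` is an already-constructed IM_MSA_Transformer):
    #
    #   original, generated = im.Batch_MSA(simplified=True, repetitions=2, use_pdf=False,
    #                                      sample_all=False, T=1, phylo=False)
    #   # `original` holds the shuffled input tokens actually used as starting points;
    #   # `generated` stacks one snapshot per entry of `self.iterations`, with up to
    #   # `repetitions * num[0]` generated sequences.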
#-------------------------------------------------------------------------------------------------------------------
def generate_MSA_context(self,
ancestor,
context,
mask_idx=32,
use_pdf=False,
sample_all=False,
T=1):
"""
Generate a new sequence by masking some entries of the original ancestor sequence and
re-predicting them through the transformer model (mask only `ancestor`, not the `context`).
`ancestor`: input sequence to be masked iteratively.
`context`: context MSA (not masked).
`p_mask`: probability that an entry of the MSA is masked.
`mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.
`use_pdf`: if it's True the function sample the token from the logits pdf
instead of getting the argmax (greedy sampling).
`sample_all`: if True all the new tokens are obtained from the logits (both
the masked and the non masked), if False the non masked tokens
are left untouched and only the masked ones are changed.
`T`: Temperature of sampling from the pdf of output logits.
"""
with torch.no_grad():
if not ancestor.is_cuda:
ancestor = ancestor.cuda()
if not context.is_cuda:
context = context.cuda()
mask = ((torch.rand(ancestor.shape) > self.p_mask).type(
torch.uint8)).cuda()
masked_ancestor = ancestor * mask + mask_idx * (1 - mask)
masked_msa_tokens = torch.zeros(
(context.shape[0], context.shape[1] + 1, context.shape[2]),
dtype=torch.int64).cuda()
masked_msa_tokens[0, 0, :] = masked_ancestor
masked_msa_tokens[:, 1:, :] = context
results = self.msa_transformer(masked_msa_tokens,
repr_layers=[12],
return_contacts=False)
results1 = results["logits"][:, 0, :, :]
results1 = results1[:, None, :, :]
msa_logits = self.softmax_tensor(x=results1, axis=3, T=T)
if use_pdf == False:
new_generation = torch.argmax(msa_logits, dim=3)[0, 0, :]
else:
Vals = torch.tensor([
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 30
],
dtype=torch.int64)
maxval = Vals[-1].cuda()
msa_logits = msa_logits[:, :, :, Vals]
msa_logits = msa_logits / (torch.sum(msa_logits,
axis=3)[:, :, :, None])
cum = torch.cumsum(msa_logits, dim=3)
idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()
idxs1 = Vals[None, None, None, :].cuda()
idxs = idxs + idxs1
sample = (torch.rand(
(cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()
idxs[torch.gt(sample[:, :, :, None], cum)] = 100
new_generation = torch.minimum(torch.amin(idxs, axis=3),
maxval)[0, 0, :]
del cum, idxs, idxs1, sample
if sample_all == False:
new_generation = ancestor * mask + new_generation * (1 - mask)
new_generation[0] = 0
del mask, masked_msa_tokens, results, results1, msa_logits
return new_generation
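    # Note on the `use_pdf` branch above: restricting the probabilities to the 21 residue/gap
    # token ids, renormalising, and taking the first index whose cumulative probability exceeds
    # a uniform draw is inverse-CDF sampling from a categorical distribution. A minimal
    # equivalent sketch (illustration only, shapes assumed as in the code above):
    #
    #   probs = msa_logits[0, 0, :, Vals]                  # (seq_len, 21), already tempered by T
    #   probs = probs / probs.sum(dim=-1, keepdim=True)
    #   picks = torch.multinomial(probs, num_samples=1)    # one categorical draw per position
    #   new_generation = Vals.cuda()[picks[:, 0]]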
#-------------------------------------------------------------------------------------------------------------------
    # Generate new sequences in a linear tree by iterating `generate_MSA_context()` on the
    # `ancestor` sequence and using the sequences in `context` as the context MSA.
def Context_MSA(self,
depth=None,
ancestor=None,
context=None,
use_pdf=False,
simplified=False,
sample_all=False,
print_all=True,
T=1):
"""
Generates a new MSA with context-generation by iterating the masking on the original ancestor sequence
using: `self.generate_MSA_context`. It masks `ancestor` (original sequence) and uses the sequences in `context` as context MSA.
---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed
or contacs), otherwise use `simplified`=True
The variable `self.iterations` must be a numpy array which specifies when (at which iterations)
the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.
If `print_all`=True then it saves the generated sequences at each iteration.
`ancestor`: input sequence to be masked iteratively.
`context`: context MSA (not masked).
`use_pdf`: if it's True the function sample the token from the logits pdf
instead of getting the argmax (greedy sampling).
`sample_all`: if True all the new tokens are obtained from the logits (both
the masked and the non masked), if False the non masked tokens
are left untouched and only the masked ones are changed.
`T`: Temperature of sampling from the pdf of output logits.
`depth`: number of generated sequences, if None the depth is the number of ancestor sequences.
"""
with torch.no_grad():
total_ran = False
if ancestor is None and context is None and depth is not None:
ALL_tokens = self.msa_data
ALL_tokens = ALL_tokens[:,
torch.randperm(ALL_tokens.shape[1]), :]
ancestor = ALL_tokens[0, :depth, :]
ALL_tokens = ALL_tokens[:,
torch.randperm(ALL_tokens.shape[1]), :]
context = ALL_tokens[:, :self.msa_batch_tokens.shape[1], :]
elif depth is None:
depth = ancestor.shape[0]
if isinstance(context, np.ndarray):
total_ran = False
elif context == 'tot-ran':
total_ran = True
            else:
                raise ValueError(
                    'Provide either `depth` alone, or both `ancestor` and `context`'
                )
all_tokens = torch.zeros(
(self.msa_batch_tokens.shape[0], self.iterations[-1] + 1,
depth, ancestor.shape[1]),
dtype=torch.int64).cuda()
            # Accept both numpy arrays and torch tensors for `ancestor`/`context`
            if isinstance(ancestor, np.ndarray):
                ancestor = torch.from_numpy(ancestor)
            ancestor = ancestor.to(dtype=torch.int64)
            if not total_ran:
                if isinstance(context, np.ndarray):
                    context = torch.from_numpy(context)
                context = context.to(dtype=torch.int64)
            else:
                ALL_tokens = self.msa_data
all_tokens[0, 0, :, :] = ancestor
if simplified:
all_tokens = all_tokens.to(dtype=torch.int8)
if self.msa_alphabet.mask_idx != 32:
raise ValueError(
f"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32"
)
# Iterate the MSA generation tree
for j in range(depth):
new_ancestor = all_tokens[0, 0, j, :]
for i in range(1, self.iterations[-1] + 1):
if total_ran:
context = (
ALL_tokens[:,
torch.randperm(ALL_tokens.shape[1]
)[:self.msa_batch_tokens.
shape[1]], :]).cuda()
new_ancestor = self.generate_MSA_context(
ancestor=new_ancestor,
context=context,
mask_idx=self.msa_alphabet.mask_idx,
use_pdf=use_pdf,
sample_all=sample_all,
T=T)
if print_all:
all_tokens[0, i, j, :] = new_ancestor
if not print_all:
all_tokens[0, -1, j, :] = new_ancestor
# torch.cuda.empty_cache()
if not print_all:
all_tokens = all_tokens[:, torch.tensor([-1]), :, :]
if simplified:
return ((context.detach().cpu()).to(
dtype=torch.int8)).numpy(), ((all_tokens.detach().cpu()).to(
dtype=torch.int8)).numpy()
else:
return context.cuda(), all_tokens.cuda()
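    # Hedged usage sketch (illustration only; `ancestor_tokens`/`context_tokens` are numpy
    # token arrays such as the ones saved by `gen_MSAs` below):
    #
    #   ctx_used, generated = im.Context_MSA(depth=None, ancestor=ancestor_tokens,
    #                                        context=context_tokens, use_pdf=True,
    #                                        simplified=True, print_all=False, T=1)
    #   # with `print_all=False` only the final iteration is kept; pass context='tot-ran'
    #   # to redraw a random context MSA at every iteration.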
# Cell
import os
import pickle
from fastcore.script import *
@call_parse
def gen_MSAs(filepath: Param(
help='Path of the input directory', type=str, default='./'
), filename: Param(
help='Name of the input file(s)', type=str, nargs='+', default=False
), new_dir: Param(
help='Name of the output directory', type=str, default=False
), pdf: Param(
help='Should I sample tokens from the pdf ? (bool)',
type=bool_arg,
default=False
), T: Param(
    help=
    'Sampling temperature used when drawing from the pdf (only relevant when `pdf` is True)',
type=float,
default=1
), sample_all: Param(
help=
'Should I sample all tokens or just the masked ones ? (True = sample all tokens)',
type=bool_arg,
default=False
), Iters: Param(
help='Number of total iterations to generate the new tokens',
type=int,
default=10
), pmask: Param(
help='Masking probability',
type=float,
default=0.1
), num: Param(
    help='Size of the batch MSAs that the MSA-Transformer receives as input',
type=int,
nargs='+',
default=100
), depth: Param(
help='Number of batches (of size num) that you want to generate',
type=int,
default=2
), generate: Param(
help=
'How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.',
type=str,
default=False
), print_all: Param(
help='Should I print the MSA after each iteration ? (bool)',
type=bool_arg,
default=False
), range_vals: Param(
help=
'First and last index of the sequences that you want to use as ancestors',
type=int,
nargs='+',
default=False
), phylo_w: Param(
help=
'Should I sample the starting sequences from the phylogeny weights ? (bool)',
type=bool_arg,
default=False)):
"Generate a new MSA either with Batch generation of Context generation. It shuffles the initial MSA and uses different slices as batch MSAs"
# Create folder
path = os.getcwd()
path1 = new_dir
if new_dir is False:
path1 = filename[0][:-6]
try:
os.mkdir(path + "/" + path1)
except OSError:
print("Creation of the directory %s failed" % (path + "/" + path1))
else:
print("Successfully created the directory %s " % (path + "/" + path1))
# Save Input MSA
print('Tokenize')
Class = IM_MSA_Transformer(filename=filename, num=[-1], filepath=filepath)
idx_list = Class.idx_list
old_tkn = Class.print_tokens()
    with open(path1 + "/dictionary-tokens.pkl", "wb") as a_file:
        pickle.dump(idx_list, a_file)
np.save(path1 + "/original-tokens.npy", old_tkn[0])
add_strs = ""
if pdf == True:
add_strs += f"_pdf(T={round(T,3)})"
print(
"We are sampling new tokens from the pdf of logits and not taking the mode of the pdf"
)
if T != 1 and pdf == False:
print(
'To sample with a Temperature you should use pdf=True, otherwise the result is the same'
)
if sample_all == False:
add_strs += "_(only-masked-sampled)"
if not generate == False:
add_strs += "_" + generate + "_(context-" + str(num[0]) + ")"
if phylo_w:
add_strs += "_phylo-w"
print('Generate Class')
Class = IM_MSA_Transformer(iterations=np.array([Iters]),
p_mask=pmask,
filename=filename,
num=num,
filepath=filepath)
print('Compute results from Class')
Class.iterations = np.array([Iters])
Class.p_mask = pmask
if generate == False:
print('Generating MSA with same size as the original one')
old_T, new_T = Class.Batch_MSA(simplified=True,
repetitions=depth,
use_pdf=pdf,
sample_all=sample_all,
T=T,
phylo=phylo_w)
NNN = min(num[0] * depth, old_T.shape[1])
elif generate == 'linear-ran' or generate == 'linear-tot-ran':
print('Generate MSA with linear context generation')
orig_tkn = np.load(path + "/" + path1 + "/original-tokens.npy")
# select ancestor and context
np.random.seed(0)
indices = np.random.permutation(orig_tkn.shape[0])
indexes_context = indices[:num[0]]
indices = np.random.permutation(orig_tkn.shape[0])
if depth == -1:
ind_ancestor = indices
elif range_vals is False:
ind_ancestor = indices[:depth]
else:
if range_vals[1] == -1:
ind_ancestor = indices[range_vals[0]:]
range_vals[1] = orig_tkn.shape[0]
else:
ind_ancestor = indices[range_vals[0]:range_vals[1]]
ancestor = orig_tkn[ind_ancestor, :]
context = orig_tkn[indexes_context, :][None, :, :]
if generate == 'linear-tot-ran':
context = 'tot-ran'
old_T, new_T = Class.Context_MSA(None,
ancestor,
context,
use_pdf=pdf,
simplified=True,
sample_all=sample_all,
print_all=print_all,
T=T)
if generate == 'linear-tot-ran':
old_T = ancestor[None, :, :]
NNN = new_T.shape[2]
    else:
        raise ValueError(
            "Select a generative process: `generate` must be False, 'linear-ran' or 'linear-tot-ran'")
# define the name of the directory to be created and create it
path2 = "Generated" + "_iter-" + str(Iters) + "_pmask-" + str(
pmask) + "_seqs-" + str(NNN) + add_strs
try:
os.mkdir(path + "/" + path1 + "/" + path2)
except OSError:
print("Creation of the directory %s failed" %
(path + "/" + path1 + "/" + path2))
else:
print("Successfully created the directory %s " %
(path + "/" + path1 + "/" + path2))
# Save data
if generate == False or generate == 'linear-tot-ran':
np.save(path1 + "/" + path2 + "/shuffled-tokens.npy", old_T[0])
else:
np.save(path1 + "/" + path2 + "/context-tokens.npy", old_T[0])
str_add = ''
if range_vals is not False:
str_add = '_range_indx_' + str(range_vals[0]) + ',' + str(
range_vals[1])
np.save(path1 + "/" + path2 + "/new-tokens" + str_add + ".npy", new_T[0])
return 1
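# Hedged example invocation (assuming a console-script entry point named `gen_MSAs` is
# generated from the nbdev/fastcore settings; adapt the entry-point name, paths and values
# to your setup):
#
#   gen_MSAs --filepath ./examples --filename PF00072.fasta --Iters 10 --pmask 0.1 \
#            --num 100 --depth 2 --pdf True
#
# With `generate` left at False this runs the batch-generation mode and saves the original
# and generated token arrays under the output directory, in a subfolder named
# "Generated_iter-10_pmask-0.1_seqs-..." plus suffixes describing the sampling options.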
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
import os
import pickle
from fastcore.script import *
from MSA import *
@call_parse
def gen_MSAs(filepath:Param(help='Path of the input directory',type=str,default='/Iterative_masking-master/examples'),
filename:Param(help='Name of the input file(s)',type=str,nargs='+',default=["PF00072.fasta"]),
new_dir:Param(help='Name of the output directory',type=str,default='results_new'),
pdf:Param(help='Should I sample tokens from the pdf ? (bool)',type=bool_arg,default=False),
             T:Param(help='Sampling temperature used when drawing from the pdf (only relevant when `pdf` is True)',type=float,default=1),
sample_all:Param(help='Should I sample all tokens or just the masked ones ? (True = sample all tokens)',type=bool_arg, default=False),
Iters:Param(help='Number of total iterations to generate the new tokens',type=int,default=20),
pmask:Param(help='Masking probability',type=float,default=0.1),
             num:Param(help='Size of the batch MSAs that the MSA-Transformer receives as input',type=int,nargs='+',default=[10]),
depth:Param(help='Number of batches (of size num) that you want to generate',type=int,default=10),
generate:Param(help='How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.',type=str, default=False),
print_all:Param(help='Should I print the MSA after each iteration ? (bool)',type=bool_arg,default=False),
range_vals:Param(help='First and last index of the sequences that you want to use as ancestors', type=int,nargs='+',default=False),
phylo_w:Param(help='Should I sample the starting sequences from the phylogeny weights ? (bool)',type=bool_arg,default=False)
):
"Generate a new MSA either with Batch generation of Context generation. It shuffles the initial MSA and uses different slices as batch MSAs"
# Create folder
path = os.getcwd()
path1 = new_dir
if new_dir is False:
path1 = filename[0][:-6]
try:
os.mkdir(path + "/" + path1)
except OSError:
print("Creation of the directory %s failed" % (path + "/" + path1))
else:
print("Successfully created the directory %s " % (path + "/" + path1))
# Save Input MSA
print('Tokenize')
Class = IM_MSA_Transformer(filename=filename,
num=num,
filepath=filepath)
idx_list = Class.idx_list
old_tkn = Class.print_tokens()
    with open(path1 + "/dictionary-tokens.pkl", "wb") as a_file:
        pickle.dump(idx_list, a_file)
np.save(path1 + "/original-tokens.npy", old_tkn[0])
add_strs = ""
if pdf==True:
add_strs += f"_pdf(T={round(T,3)})"
print(
"We are sampling new tokens from the pdf of logits and not taking the mode of the pdf"
)
if T!=1 and pdf==False:
print('To sample with a Temperature you should use pdf=True, otherwise the result is the same')
if sample_all == False:
add_strs += "_(only-masked-sampled)"
if not generate==False:
add_strs += "_"+generate+"_(context-"+str(num[0])+")"
if phylo_w:
add_strs += "_phylo-w"
print('Generate Class')
Class = IM_MSA_Transformer(iterations=np.array([Iters]),
p_mask=pmask,
filename=filename,
num=num,
filepath=filepath)
print('Compute results from Class')
Class.iterations = np.array([Iters])
Class.p_mask = pmask
if generate == False:
print('Generating MSA with same size as the original one')
old_T, new_T = Class.Batch_MSA(simplified=True,
repetitions=depth,
use_pdf=pdf, sample_all=sample_all, T=T, phylo=phylo_w)
NNN = min(num[0] * depth, old_T.shape[1])
elif generate=='linear-ran' or generate=='linear-tot-ran':
print('Generate MSA with linear context generation')
orig_tkn = np.load(path + "/" + path1 + "/original-tokens.npy")
# select ancestor and context
np.random.seed(0)
indices = np.random.permutation(orig_tkn.shape[0])
indexes_context = indices[:num[0]]
indices = np.random.permutation(orig_tkn.shape[0])
if depth == -1:
ind_ancestor = indices
elif range_vals is False:
ind_ancestor = indices[:depth]
else:
if range_vals[1] == -1 :
ind_ancestor = indices[range_vals[0]:]
range_vals[1] = orig_tkn.shape[0]
else:
ind_ancestor = indices[range_vals[0]:range_vals[1]]
ancestor = orig_tkn[ind_ancestor,:]
context = orig_tkn[indexes_context,:][None,:,:]
if generate=='linear-tot-ran':
context = 'tot-ran'
old_T, new_T = Class.Context_MSA(None, ancestor, context, use_pdf=pdf, simplified=True, sample_all=sample_all, print_all=print_all, T=T)
if generate=='linear-tot-ran':
old_T = ancestor[None,:,:]
NNN = new_T.shape[2]
    else:
        raise ValueError(
            "Select a generative process: `generate` must be False, 'linear-ran' or 'linear-tot-ran'")
# define the name of the directory to be created and create it
path2 = "Generated" + "_iter-" + str(
Iters) + "_pmask-" + str(pmask) + "_seqs-" + str(NNN) + add_strs
try:
os.mkdir(path + "/" + path1 + "/" + path2)
except OSError:
print("Creation of the directory %s failed" % (path + "/" +
path1 + "/" + path2))
else:
print("Successfully created the directory %s " % (path + "/" +
path1 + "/" + path2))
# Save data
if generate == False or generate=='linear-tot-ran':
np.save(path1 + "/" + path2 + "/shuffled-tokens.npy", old_T[0])
else:
np.save(path1 + "/" + path2 + "/context-tokens.npy", old_T[0])
str_add = ''
if range_vals is not False:
str_add = '_range_indx_'+str(range_vals[0])+','+str(range_vals[1])
np.save(path1 + "/" + path2 + "/new-tokens"+str_add+".npy", new_T[0])
return 1
print(show_doc(gen_MSAs))
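# Hedged example (the script file name is an assumption; adjust `filepath`/`filename`
# so they point at your input MSA):
#
#   python MSA_generate.py --new_dir results_new --Iters 20 --pmask 0.1
#
# With the defaults above this tokenizes PF00072.fasta, runs 20 masking iterations with
# p_mask = 0.1, and generates `depth` = 10 batches of `num` = 10 sequences each.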