{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp core" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Iterative_masking\n", "\n", "> Use MSA Transformer to generate synthetic protein sequences by masking iteratively the same MSA." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "\n", "import numpy as np\n", "import esm\n", "from numba import njit, prange\n", "import torch\n", "from Bio import SeqIO\n", "import itertools\n", "from typing import List, Tuple\n", "import string\n", "from warnings import warn\n", "\n", "torch.set_grad_enabled(False)\n", "\n", "\n", "# Iterative masking MSA-Transformer\n", "class IM_MSA_Transformer:\n", " \"\"\"Class that implement the Iterative masking algorithm\"\"\"\n", " def __init__(self,\n", " iterations=None,\n", " p_mask=None,\n", " filename=None,\n", " num=None,\n", " filepath=None):\n", "\n", " self.iterations = iterations # number of iterations used to generate the MSA\n", " self.p_mask = p_mask # masking probability for the MSA generation\n", " #---------------------------------------------------------------------------------------\n", " # Delete lowercase characters and punctuations from a string (input fasta file)\n", " self.deletekeys = dict.fromkeys(string.ascii_lowercase)\n", " self.deletekeys[\".\"] = None\n", " self.deletekeys[\"*\"] = None\n", " self.translation = str.maketrans(self.deletekeys)\n", " #---------------------------------------------------------------------------------------\n", " if filename is None or num is None or filepath is None:\n", " raise ValueError(\"`filepath`, `filename` and `num` must be specified to import the MSA\")\n", " # Import Transformer model\n", " self.msa_transformer, self.msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()\n", " self.msa_transformer = self.msa_transformer.eval().cuda()\n", " self.msa_batch_converter = self.msa_alphabet.get_batch_converter()\n", " self.idx_list = self.msa_alphabet.tok_to_idx\n", " print('MSA Transformer model imported')\n", "\n", " # If filename is an array then it's the input MSA\n", " with torch.no_grad():\n", " if isinstance(filename,np.ndarray):\n", " self.msa_data = torch.Tensor(filename).type(torch.int64)\n", " if len(filename.shape) != 3:\n", " raise ValueError(\"`filename` should be an array with 3 axes\")\n", " self.msa_batch_tokens = self.msa_data[:, :num[0], :]\n", " print('Using MSA given in input')\n", " else:\n", " if len(num) != len(filename):\n", " raise ValueError(\"`filename` and `num` must have the same length\")\n", " #---------------------------------------------------------------------------------------\n", " # Import MSAs\n", " self.msa_data = []\n", " for ff, nn in zip(filename, num):\n", " self.msa_data += [self.read_msa(filepath + '/' + ff, nn)]\n", " print('MSA Imported')\n", " #---------------------------------------------------------------------------------------\n", " # Create tokens starting from MSA\n", " self.msa_batch_labels, self.msa_batch_strs, self.msa_batch_tokens = self.msa_batch_converter(\n", " self.msa_data)\n", " self.msa_data = (self.msa_batch_tokens).clone()\n", " print(f'We are using batch MSAs of {num[0]} sequences')\n", " self.msa_batch_tokens = self.msa_batch_tokens[:, :num[0], :]\n", "\n", " # Import tokens into cuda\n", " 
self.msa_batch_tokens = self.msa_batch_tokens.cuda()\n", "\n", " print('MSA converted into tokens tensor of size and type:')\n", " print(self.msa_batch_tokens.size(), self.msa_batch_tokens.dtype)\n", "\n", " #---------------------------------------------------------------------------------------\n", " # Useful functions for handling string sequences\n", " def read_sequence(self, filename: str) -> Tuple[str, str]:\n", " \"\"\" Reads the first (reference) sequences from a fasta or MSA file.\"\"\"\n", " record = next(SeqIO.parse(filename, \"fasta\"))\n", " return record.description, str(record.seq)\n", "\n", " def remove_insertions(self, sequence: str) -> str:\n", " \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n", " return sequence.translate(self.translation)\n", "\n", " def read_msa(self, filename: str, nseq: int) -> List[Tuple[str, str]]:\n", " \"\"\" Reads the first nseq sequences from an MSA file, automatically removes insertions.\"\"\"\n", " tot = len([elem.id for elem in SeqIO.parse(filename, \"fasta\")])\n", " print(f'Number of sequences in {filename}: ', tot)\n", " return [\n", " (record.description, self.remove_insertions(str(record.seq)))\n", " for record in itertools.islice(SeqIO.parse(filename, \"fasta\"), tot)]\n", "\n", "#-----------------------------------------------------------------------------------------------------------------------\n", "# USEFUL FUNCTIONS TO RUN THE MSA TRANSFORMER ON INFERENCE MODE\n", "#-----------------------------------------------------------------------------------------------------------------------\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def print_tokens(self, tokens=None):\n", " \"\"\"\n", " Outputs (on the cpu) the input `tokens` of the MSA, detaching them from the GPU.\n", " \"\"\"\n", " with torch.no_grad():\n", " if tokens is None:\n", " return ((self.msa_batch_tokens.detach().cpu()).to(\n", " dtype=torch.int8)).numpy()\n", " else:\n", " return ((tokens.detach().cpu()).to(dtype=torch.int8)).numpy()\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def compute_embeddings(self, tokens=None, lyrs=[12]):\n", " \"\"\"\n", " Starting from the `tokens`, use the model to predict their output embeddings and their associated\n", " logits (when softmaxed they give the probability of each token)\n", " `lyrs`: list of the layers from which extracting the embeddings (# 12 is the last layer)\n", " \"\"\"\n", " with torch.no_grad():\n", " if tokens is None:\n", " tokens = self.msa_batch_tokens\n", " if not tokens.is_cuda:\n", " tokens = tokens.cuda()\n", " results = self.msa_transformer(tokens,\n", " repr_layers=lyrs,\n", " return_contacts=False)\n", " token_representations = results[\"representations\"][lyrs[0]].detach().cpu().numpy()\n", " logits = results[\"logits\"].detach().cpu().numpy()\n", " del results\n", " return token_representations, logits\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def compute_contacts(self, tokens=None):\n", " \"\"\"\n", " Starting from the `tokens`, use the model to predict the contact matrix of each MSA\n", " \"\"\"\n", " with torch.no_grad():\n", " if tokens is None:\n", " tokens = self.msa_batch_tokens\n", " if not tokens.is_cuda:\n", " tokens = tokens.cuda()\n", " msa_contacts = 
self.msa_transformer.predict_contacts(tokens).cpu()\n", " return msa_contacts\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " @njit(parallel=True)\n", " def Weights_Phylogeny(tkn, delta=0.8):\n", " \"\"\"\n", " Compute the Phylogeny weights of the sequences\n", " `tkn`: the 2d array of tokens of one MSA, it should not have the first token (0)\n", " and it should end before the start of the padding tokens (1).\n", " `delta`: the phylogeny parameter\n", " \"\"\"\n", " depth, length = tkn.shape\n", "\n", " def _inner(seq1, seq2):\n", " return np.sum(seq1 != seq2) / length\n", "\n", " weights = np.empty(depth, dtype=np.float64)\n", " for i in prange(depth):\n", " dists = np.empty(depth, dtype=np.float64)\n", " for j in range(depth):\n", " dists[j] = _inner(tkn[i], tkn[j])\n", " within_neighbourhood = np.sum(dists < 1 - delta)\n", " weights[i] = 1 / within_neighbourhood\n", " return weights\n", "\n", "\n", "#-----------------------------------------------------------------------------------------------------------------------\n", "# USEFUL FUNCTIONS FOR THE MSA GENERATION WITH THE TRANSFORMER ON INFERENCE MODE\n", "#-----------------------------------------------------------------------------------------------------------------------\n", "\n", "#-------------------------------------------------------------------------------------------------------------------\n", "# Softmax of the logits tensor\n", "\n", " def softmax_tensor(self, x, axis, T=1):\n", " \"\"\"\n", " Compute softmax values for each sets of scores in `x` where `x` is the 4-d tensor of logits\n", " and `T` is the sampling temperature.\n", " \"\"\"\n", " return torch.exp(x/T) / torch.sum(torch.exp(x/T), axis=axis)[:, :, :, None]\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def generate_MSA(self, MSA_tokens, mask_idx=32, use_pdf=False, sample_all=False, T=1):\n", " \"\"\"\n", " Generate a new MSA by masking some entries of the original MSA and\n", " re-predicting them through MSA Transformer.\n", "\n", " `MSA_tokens`: input tokens.\n", "\n", " `p_mask`: probability that an entry of the MSA is masked.\n", "\n", " `mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.\n", "\n", " `use_pdf`: if it's True the function sample the token from the logits pdf \n", " instead of getting the argmax (greedy sampling).\n", "\n", " `sample_all`: if True all the new tokens are obtained from the logits (both\n", " the masked and the non masked), if False the non masked tokens\n", " are left untouched and only the masked ones are changed.\n", "\n", " `T`: Temperature of sampling from the pdf of output logits.\n", " \"\"\"\n", " with torch.no_grad():\n", " if not MSA_tokens.is_cuda:\n", " MSA_tokens = MSA_tokens.cuda()\n", " mask = ((torch.rand(MSA_tokens.shape) > self.p_mask).type(\n", " torch.uint8)).cuda()\n", " masked_msa_tokens = MSA_tokens * mask + mask_idx * (1 - mask)\n", " results = self.msa_transformer(masked_msa_tokens,\n", " repr_layers=[12],\n", " return_contacts=False)\n", " msa_logits = self.softmax_tensor(x=results[\"logits\"], axis=3, T=T)\n", " if use_pdf == False:\n", " new_msa_tokens = torch.argmax(msa_logits, dim=3)\n", " else:\n", " Vals = torch.tensor([\n", " 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,\n", " 20, 21, 22, 23, 30\n", " ],\n", " dtype=torch.int64)\n", " maxval = Vals[-1].cuda()\n", " msa_logits = 
msa_logits[:, :, :, Vals]\n", " msa_logits = msa_logits / (torch.sum(msa_logits,\n", " axis=3)[:, :, :, None])\n", " cum = torch.cumsum(msa_logits, dim=3)\n", " idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()\n", " idxs1 = Vals[None, None, None, :].cuda()\n", " idxs = idxs + idxs1\n", " sample = (torch.rand(\n", " (cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()\n", " idxs[torch.gt(sample[:, :, :, None], cum)] = 100\n", " new_msa_tokens = torch.minimum(torch.amin(idxs, axis=3),\n", " maxval)\n", " del cum, idxs, idxs1, sample\n", " if sample_all == False:\n", " new_msa_tokens = MSA_tokens * mask + new_msa_tokens * (1 - mask)\n", " new_msa_tokens[:, :, 0] = 0\n", " del mask, masked_msa_tokens, results, msa_logits\n", " return new_msa_tokens\n", "\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def NEW_MSA(self, use_pdf=False, simplified=False, sample_all=False, T=1):\n", " \"\"\"\n", " Generate a new MSA by iteratively calling the masked MSA generator defined in: `self.generate_MSA`.\n", "\n", " ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n", " or contacs), otherwise use `simplified`=True.\n", "\n", " The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n", " the tokens should be saved. The last element of the array gives the maximum number of iterations that should be done.\n", "\n", " `use_pdf`: if it's True the function sample the token from the logits pdf \n", " instead of getting the argmax (greedy sampling).\n", "\n", " `sample_all`: if True all the new tokens are obtained from the logits (both\n", " the masked and the non masked), if False the non masked tokens\n", " are left untouched and only the masked ones are changed.\n", "\n", " `T`: Temperature of sampling from the pdf of output logits.\n", " \"\"\"\n", " if self.iterations is None or self.p_mask is None:\n", " raise ValueError(\n", " \"Both `iterations` (numpy array) and `p_mask` (float) must be specified to generate a new MSA\"\n", " )\n", " max_iter = self.iterations[-1]\n", " with torch.no_grad():\n", " new_msa_tokens = self.msa_batch_tokens.clone()\n", " all_tokens = torch.zeros(\n", " (len(self.iterations), self.msa_batch_tokens.shape[0],\n", " self.msa_batch_tokens.shape[1],\n", " self.msa_batch_tokens.shape[2]),\n", " dtype=torch.int64)\n", " if simplified:\n", " all_tokens = all_tokens.to(dtype=torch.int8)\n", " if self.msa_alphabet.mask_idx != 32:\n", " raise ValueError(\n", " f\"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32\"\n", " )\n", " # Iterate the MSA generation process\n", " j = 0\n", " for i in range(max_iter):\n", " new_msa_tokens = self.generate_MSA(\n", " MSA_tokens=new_msa_tokens,\n", " mask_idx=self.msa_alphabet.mask_idx,\n", " use_pdf=use_pdf, sample_all=sample_all, T=T)\n", " if np.any((i + 1) == self.iterations):\n", " # Save the tokens at the specified iterations\n", " if simplified:\n", " all_tokens[j,\n", " ...] = (new_msa_tokens.clone().detach().cpu()).to(\n", " dtype=torch.int8)\n", " else:\n", " all_tokens[j, ...] 
= new_msa_tokens.clone()\n", " j += 1\n", " del new_msa_tokens\n", " if simplified:\n", " return all_tokens.numpy()\n", " else:\n", " return all_tokens.cuda()\n", "\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " def Batch_MSA(self, use_pdf=False, simplified=False, repetitions=2, sample_all=False, T=1, phylo=False):\n", " \"\"\"\n", " Generate a full MSA by calling with different input MSAs the iterative MSA generator defined\n", " in: `self.NEW_MSA`.\n", "\n", " ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n", " or contacs), otherwise use `simplified`=True\n", "\n", " The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n", " the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n", "\n", " `repetitions`: the number of times self.NEW_MSA() is repeated with a different input MSA.\n", "\n", " `use_pdf`: if it's True the function sample the token from the logits pdf \n", " instead of getting the argmax (greedy sampling).\n", "\n", " `sample_all`: if True all the new tokens are obtained from the logits (both\n", " the masked and the non masked), if False the non masked tokens\n", " are left untouched and only the masked ones are changed.\n", "\n", " `T`: Temperature of sampling from the pdf of output logits.\n", "\n", " `phylo`: if True the start sequences are sampled from phylogeny weights instead of randomly.\n", " \"\"\"\n", " with torch.no_grad():\n", " all_tokens = np.zeros(\n", " (len(self.iterations), self.msa_batch_tokens.shape[0],\n", " self.msa_batch_tokens.shape[1] * repetitions,\n", " self.msa_batch_tokens.shape[2]),\n", " dtype=np.int64)\n", " if simplified:\n", " all_tokens = all_tokens.astype('int8')\n", " ALL_tokens = self.msa_data\n", " depth = self.msa_batch_tokens.shape[1]\n", " if repetitions * depth > ALL_tokens.shape[1]:\n", " all_tokens = np.zeros(\n", " (len(self.iterations), self.msa_batch_tokens.shape[0],\n", " ALL_tokens.shape[1], self.msa_batch_tokens.shape[2]),\n", " dtype=np.int64)\n", "\n", " if not phylo:\n", " ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n", " else:\n", " _ = self.Weights_Phylogeny(ALL_tokens[0, :20, :], delta=0.8)\n", " phylo_w = self.Weights_Phylogeny(ALL_tokens[0, :, :], delta=0.8)\n", " indxs = torch.multinomial(phylo_w, ALL_tokens.shape[1], replacement=True)\n", " ALL_tokens = ALL_tokens[:, indxs, :]\n", " for i in range(repetitions):\n", " ind = torch.arange(i * depth, (i + 1) * depth)\n", " if (i + 1) * depth > ALL_tokens.shape[1]:\n", " ind = torch.arange(i * depth, ALL_tokens.shape[1])\n", " self.msa_batch_tokens = ALL_tokens[:, ind, :]\n", " self.msa_batch_tokens = self.msa_batch_tokens.cuda()\n", " all_tokens[:, :,\n", " ind.numpy(), :] = self.NEW_MSA(use_pdf=use_pdf, simplified=simplified, sample_all=sample_all, T=T)\n", " if (i + 1) * depth > ALL_tokens.shape[1]:\n", " break\n", "\n", " if simplified:\n", " return (ALL_tokens[:, :repetitions *\n", " depth, :].numpy()).astype('int8'), all_tokens\n", " else:\n", " return ALL_tokens[:, :repetitions *\n", " depth, :], torch.from_numpy(all_tokens).cuda()\n", "\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", "\n", " def generate_MSA_context(self, ancestor, context, mask_idx=32, use_pdf=False, sample_all=False, T=1):\n", " 
\"\"\"\n", " Generate a new sequence by masking some entries of the original ancestor sequence and\n", " re-predicting them through the transformer model (mask only `ancestor`, not the `context`).\n", "\n", " `ancestor`: input sequence to be masked iteratively.\n", "\n", " `context`: context MSA (not masked).\n", "\n", " `p_mask`: probability that an entry of the MSA is masked.\n", "\n", " `mask_idx`: masking index (as interpreted by the model), for MSA-Tr it's 32.\n", "\n", " `use_pdf`: if it's True the function sample the token from the logits pdf \n", " instead of getting the argmax (greedy sampling).\n", "\n", " `sample_all`: if True all the new tokens are obtained from the logits (both\n", " the masked and the non masked), if False the non masked tokens\n", " are left untouched and only the masked ones are changed.\n", "\n", " `T`: Temperature of sampling from the pdf of output logits.\n", " \"\"\"\n", " with torch.no_grad():\n", "\n", " if not ancestor.is_cuda:\n", " ancestor = ancestor.cuda()\n", " if not context.is_cuda:\n", " context = context.cuda()\n", "\n", " mask = ((torch.rand(ancestor.shape) > self.p_mask).type(torch.uint8)).cuda()\n", " masked_ancestor = ancestor * mask + mask_idx * (1 - mask)\n", " \n", " masked_msa_tokens = torch.zeros((context.shape[0],\n", " context.shape[1]+1,\n", " context.shape[2]),\n", " dtype=torch.int64).cuda()\n", " masked_msa_tokens[0, 0, :] = masked_ancestor\n", " masked_msa_tokens[:, 1:, :] = context\n", "\n", " results = self.msa_transformer(masked_msa_tokens,\n", " repr_layers=[12],\n", " return_contacts=False)\n", " results1 = results[\"logits\"][:,0,:,:]\n", " results1 = results1[:,None,:,:]\n", " msa_logits = self.softmax_tensor(x=results1, axis=3, T=T)\n", "\n", " if use_pdf == False:\n", " new_generation = torch.argmax(msa_logits, dim=3)[0,0,:]\n", " else:\n", " Vals = torch.tensor([\n", " 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,\n", " 20, 21, 22, 23, 30],dtype=torch.int64)\n", " maxval = Vals[-1].cuda()\n", " msa_logits = msa_logits[:, :, :, Vals]\n", " msa_logits = msa_logits / (torch.sum(msa_logits,\n", " axis=3)[:, :, :, None])\n", " cum = torch.cumsum(msa_logits, dim=3)\n", " idxs = torch.zeros_like(cum, dtype=torch.int64).cuda()\n", " idxs1 = Vals[None, None, None, :].cuda()\n", " idxs = idxs + idxs1\n", " sample = (torch.rand(\n", " (cum.shape[0], cum.shape[1], cum.shape[2]))).cuda()\n", " idxs[torch.gt(sample[:, :, :, None], cum)] = 100\n", " new_generation = torch.minimum(torch.amin(idxs, axis=3),\n", " maxval)[0,0,:]\n", " del cum, idxs, idxs1, sample\n", " \n", " if sample_all == False:\n", " new_generation = ancestor * mask + new_generation * (1 - mask)\n", " new_generation[0] = 0\n", "\n", " del mask, masked_msa_tokens, results, results1, msa_logits\n", " return new_generation\n", "\n", " #-------------------------------------------------------------------------------------------------------------------\n", " # Generate new sequence in a Linear tree by reiterating the function `generate_MSA_context()` starting from the sequence:\n", " # `ancestor` (original sequence) and using the sequences in `context` as context MSA.\n", " def Context_MSA(self, depth=None, ancestor=None, context=None, use_pdf=False, simplified=False, sample_all=False, print_all=True, T=1):\n", " \"\"\"\n", " Generates a new MSA with context-generation by iterating the masking on the original ancestor sequence\n", " using: `self.generate_MSA_context`. 
It masks `ancestor` (original sequence) and uses the sequences in `context` as context MSA.\n", " \n", " ---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n", " or contacs), otherwise use `simplified`=True\n", "\n", " The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n", " the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n", " If `print_all`=True then it saves the generated sequences at each iteration.\n", "\n", " `ancestor`: input sequence to be masked iteratively.\n", "\n", " `context`: context MSA (not masked).\n", "\n", " `use_pdf`: if it's True the function sample the token from the logits pdf \n", " instead of getting the argmax (greedy sampling).\n", "\n", " `sample_all`: if True all the new tokens are obtained from the logits (both\n", " the masked and the non masked), if False the non masked tokens\n", " are left untouched and only the masked ones are changed.\n", "\n", " `T`: Temperature of sampling from the pdf of output logits.\n", "\n", " `depth`: number of generated sequences, if None the depth is the number of ancestor sequences.\n", " \"\"\"\n", " with torch.no_grad():\n", " total_ran=False\n", " if ancestor is None and context is None and depth is not None:\n", " ALL_tokens = self.msa_data\n", " ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n", " ancestor = ALL_tokens[0,:depth,:]\n", " ALL_tokens = ALL_tokens[:, torch.randperm(ALL_tokens.shape[1]), :]\n", " context = ALL_tokens[:,:self.msa_batch_tokens.shape[1],:]\n", " elif depth is None:\n", " depth = ancestor.shape[0]\n", " if isinstance(context,np.ndarray):\n", " total_ran=False\n", " elif context=='tot-ran':\n", " total_ran=True\n", " else:\n", " print('ERROR, either you give depth or you give ancestor and context')\n", "\n", " all_tokens = torch.zeros((self.msa_batch_tokens.shape[0],\n", " self.iterations[-1]+1,\n", " depth,\n", " ancestor.shape[1]),\n", " dtype=torch.int64).cuda()\n", "\n", " ancestor = torch.from_numpy(ancestor).to(dtype=torch.int64)\n", " if not total_ran:\n", " context = torch.from_numpy(context).to(dtype=torch.int64)\n", " if total_ran:\n", " ALL_tokens = self.msa_data\n", "\n", " all_tokens[0, 0, :, :] = ancestor\n", "\n", " if simplified:\n", " all_tokens = all_tokens.to(dtype=torch.int8)\n", " if self.msa_alphabet.mask_idx != 32:\n", " raise ValueError(\n", " f\"The token used for masking is {self.msa_alphabet.mask_idx} instead of 32\"\n", " )\n", " \n", " # Iterate the MSA generation tree\n", " for j in range(depth):\n", " new_ancestor = all_tokens[0, 0, j, :]\n", " for i in range(1,self.iterations[-1]+1):\n", " if total_ran:\n", " context = (ALL_tokens[:, torch.randperm(ALL_tokens.shape[1])[:self.msa_batch_tokens.shape[1]], :]).cuda()\n", " new_ancestor = self.generate_MSA_context(ancestor=new_ancestor,context=context, mask_idx=self.msa_alphabet.mask_idx, use_pdf=use_pdf, sample_all=sample_all, T=T)\n", " if print_all:\n", " all_tokens[0, i, j, :] = new_ancestor\n", " if not print_all:\n", " all_tokens[0, -1, j, :] = new_ancestor\n", " # torch.cuda.empty_cache()\n", "\n", " if not print_all:\n", " all_tokens = all_tokens[:,torch.tensor([-1]),:,:]\n", "\n", " if simplified:\n", " return ((context.detach().cpu()).to(dtype=torch.int8)).numpy(), ((all_tokens.detach().cpu()).to(dtype=torch.int8)).numpy()\n", " else:\n", " return context.cuda(), all_tokens.cuda()\n" ] }, { "cell_type": "code", 
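"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal usage sketch (illustrative only): it assumes a CUDA GPU, that the esm_msa1b weights\n", "# can be downloaded, and that 'data/family.fasta' is a placeholder alignment containing at\n", "# least num[0] * repetitions sequences. Names and numbers below are not part of the library.\n", "im = IM_MSA_Transformer(iterations=np.array([10, 50, 200]),  # save tokens after these iterations\n", "                        p_mask=0.1,                          # masking probability per MSA entry\n", "                        filename=['family.fasta'],           # placeholder fasta file name\n", "                        num=[600],                           # sequences per batch MSA\n", "                        filepath='data')                     # placeholder input directory\n", "\n", "# Batch generation: `repetitions` batches of num[0] sequences, sampling new tokens from the pdf at T=1\n", "start_tokens, new_tokens = im.Batch_MSA(use_pdf=True, simplified=True, repetitions=2, T=1)\n", "# new_tokens has shape (len(iterations), 1, num[0] * repetitions, sequence_length + 1), dtype int8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sketch above is only illustrative: the fasta name, directory and batch sizes are placeholders. Instantiating the class downloads the `esm_msa1b_t12_100M_UR50S` weights (if not already cached) and moves the model to the GPU. With `simplified`=True the generated tokens are returned on the CPU as `int8` numpy arrays; use `simplified`=False only if you need the tokens on CUDA, e.g. to feed them to `compute_embeddings` or `compute_contacts`. For context generation, `Context_MSA` expects `ancestor` and `context` as integer token arrays (numpy); see its docstring above." ] }, { "cell_type": "code",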
"execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
class IM_MSA_Transformer [source]\n",
"\n",
"> IM_MSA_Transformer(**`iterations`**=*`None`*, **`p_mask`**=*`None`*, **`filename`**=*`None`*, **`num`**=*`None`*, **`filepath`**=*`None`*)\n",
"\n",
"Class that implement the Iterative masking algorithm"
],
"text/plain": [
"IM_MSA_Transformer.Batch_MSA[source]IM_MSA_Transformer.Batch_MSA(**`use_pdf`**=*`False`*, **`simplified`**=*`False`*, **`repetitions`**=*`2`*, **`sample_all`**=*`False`*, **`T`**=*`1`*, **`phylo`**=*`False`*)\n",
"\n",
"Generate a full MSA by calling with different input MSAs the iterative MSA generator defined\n",
"in: `self.NEW_MSA`.\n",
"\n",
"---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
"The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
"the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"\n",
"`repetitions`: the number of times self.NEW_MSA() is repeated with a different input MSA.\n",
"\n",
"`use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
"`sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
"`T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
"`phylo`: if True the start sequences are sampled from phylogeny weights instead of randomly."
],
"text/plain": [
"IM_MSA_Transformer.Context_MSA[source]IM_MSA_Transformer.Context_MSA(**`depth`**=*`None`*, **`ancestor`**=*`None`*, **`context`**=*`None`*, **`use_pdf`**=*`False`*, **`simplified`**=*`False`*, **`sample_all`**=*`False`*, **`print_all`**=*`True`*, **`T`**=*`1`*)\n",
"\n",
"Generates a new MSA with context-generation by iterating the masking on the original ancestor sequence\n",
"using: `self.generate_MSA_context`. It masks `ancestor` (original sequence) and uses the sequences in `context` as context MSA.\n",
"\n",
"---> Use this function with `simplified`=False only if you need tokens in cuda ! (i.e. if you want to compute embed\n",
" or contacs), otherwise use `simplified`=True\n",
"\n",
"The variable `self.iterations` must be a numpy array which specifies when (at which iterations)\n",
"the tokens must be saved. The last element of the array gives the maximum number of iterations that should be done.\n",
"If `print_all`=True then it saves the generated sequences at each iteration.\n",
"\n",
"`ancestor`: input sequence to be masked iteratively.\n",
"\n",
"`context`: context MSA (not masked).\n",
"\n",
"`use_pdf`: if it's True the function sample the token from the logits pdf \n",
" instead of getting the argmax (greedy sampling).\n",
"\n",
"`sample_all`: if True all the new tokens are obtained from the logits (both\n",
" the masked and the non masked), if False the non masked tokens\n",
" are left untouched and only the masked ones are changed.\n",
"\n",
"`T`: Temperature of sampling from the pdf of output logits.\n",
"\n",
"`depth`: number of generated sequences, if None the depth is the number of ancestor sequences."
],
"text/plain": [
"gen_MSAs[source]gen_MSAs(**`filepath`**:\"Path of the input directory\", **`filename`**:\"Name of the input file(s)\", **`new_dir`**:\"Name of the output directory\", **`pdf`**:\"Should I sample tokens from the pdf ? (bool)\", **`T`**:\"Which is the sampling Temperature from the pdf ? (only when `pdf` is True)\", **`sample_all`**:\"Should I sample all tokens or just the masked ones ? (True = sample all tokens)\", **`Iters`**:\"Number of total iterations to generate the new tokens\", **`pmask`**:\"Masking probability\", **`num`**:\"Size of the batches MSAs which the MSA-Transformer receives as input\", **`depth`**:\"Number of batches (of size num) that you want to generate\", **`generate`**:\"How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.\", **`print_all`**:\"Should I print the MSA after each iteration ? (bool)\", **`range_vals`**:\"First and last index of the sequences that you want to use as ancestors\", **`phylo_w`**:\"Should I sample the starting sequences from the phylogeny weights ? (bool)\")\n",
"\n",
"Generate a new MSA either with Batch generation of Context generation. It shuffles the initial MSA and uses different slices as batch MSAs"
],
"text/plain": [
"