MSA_next.py 6.15 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import pickle
from fastcore.script import *
from MSA import *

def _make_dir(dirpath):
    "Create the single directory `dirpath`, printing whether it worked; failures are reported, not raised."
    # Failure is deliberately non-fatal (e.g. the directory already exists from a previous run):
    # the caller keeps going and writes its outputs into whatever directory is already there.
    try:
        os.mkdir(dirpath)
    except OSError:
        print("Creation of the directory %s failed" % dirpath)
    else:
        print("Successfully created the directory %s " % dirpath)


@call_parse
def gen_MSAs(filepath:Param(help='Path of the input directory',type=str,default='/Iterative_masking-master/examples'),
         filename:Param(help='Name of the input file(s)',type=str,nargs='+',default=["PF00072.fasta"]),
         new_dir:Param(help='Name of the output directory',type=str,default='results_new'),
         pdf:Param(help='Should I sample tokens from the pdf ? (bool)',type=bool_arg,default=False),
         T:Param(help='Which is the sampling Temperature from the pdf ? (only when `pdf` is True)',type=float,default=1),
         sample_all:Param(help='Should I sample all tokens or just the masked ones ? (True = sample all tokens)',type=bool_arg, default=False),
         Iters:Param(help='Number of total iterations to generate the new tokens',type=int,default=20),
         pmask:Param(help='Masking probability',type=float,default=0.1),
         num:Param(help='Size of the batches MSAs which the MSA-Transformer receives as input',type=int,nargs='+',default=[10]),
         depth:Param(help='Number of batches (of size num) that you want to generate',type=int,default=10),
         generate:Param(help='How should I generate sequences ? False (=Batch generation) or Linear with context (=linear-ran/linear-tot-ran), `-ran` means that the context MSA is sampled randomly (once) while `-tot-ran` means that it is sampled randomly each time.',type=str, default=False),
         print_all:Param(help='Should I print the MSA after each iteration ? (bool)',type=bool_arg,default=False),
         range_vals:Param(help='First and last index of the sequences that you want to use as ancestors', type=int,nargs='+',default=False),
         phylo_w:Param(help='Should I sample the starting sequences from the phylogeny weights ? (bool)',type=bool_arg,default=False)
         ):
    "Generate a new MSA either with Batch generation or Context generation. It shuffles the initial MSA and uses different slices as batch MSAs"
    # Files written (all relative to the current working directory):
    #   <new_dir>/dictionary-tokens.pkl   -- `idx_list` of the tokenizer class (pickled)
    #   <new_dir>/original-tokens.npy     -- first element of `print_tokens()` for the input MSA
    #   <new_dir>/Generated_iter-<Iters>_pmask-<pmask>_seqs-<NNN><options>/
    #       shuffled-tokens.npy (batch, linear-tot-ran) or context-tokens.npy (linear-ran)
    #                                     -- old_T[0] as returned by the generator
    #       new-tokens[_range_indx_<first>,<last>].npy -- new_T[0], the generated tokens
    # Returns 1 on success.
    # Raises ValueError for an unknown `generate` mode or a `range_vals` with fewer than two values.

    # --- Validate arguments before doing any (slow) work or touching the disk ---
    # fastcore parses `--generate` with type=str, so `--generate False` (as the help suggests)
    # arrives as the string "False"; treat it like the default False (= Batch generation).
    if isinstance(generate, str) and generate.lower() == 'false':
        generate = False
    if generate is not False and generate not in ('linear-ran', 'linear-tot-ran'):
        raise ValueError("ERROR: Select a generative process: False (batch), 'linear-ran' or "
                         f"'linear-tot-ran' (got {generate!r})")
    if range_vals is not False:
        if len(range_vals) < 2:
            raise ValueError('range_vals needs two values: the first and last ancestor index '
                             f'(use -1 as last for "until the end"), got {range_vals!r}')
        # Local copy: the -1 sentinel is resolved below and must not leak into the caller's list.
        range_vals = list(range_vals)

    # --- Create the output folder (in the current working directory) ---
    path = os.getcwd()
    path1 = new_dir
    if new_dir is False:
        path1 = filename[0][:-6]  # drop a 6-character extension such as ".fasta"
    _make_dir(path + "/" + path1)

    # --- Tokenize and save the input MSA ---
    print('Tokenize')
    Class = IM_MSA_Transformer(filename=filename,
                               num=num,
                               filepath=filepath)
    idx_list = Class.idx_list
    old_tkn = Class.print_tokens()
    with open(path1 + "/dictionary-tokens.pkl", "wb") as a_file:
        pickle.dump(idx_list, a_file)
    np.save(path1 + "/original-tokens.npy", old_tkn[0])

    # --- Suffix describing the sampling options; it becomes part of the run folder name ---
    add_strs = ""
    if pdf:
        add_strs += f"_pdf(T={round(T,3)})"
        print(
            "We are sampling new tokens from the pdf of logits and not taking the mode of the pdf"
        )
    if T != 1 and not pdf:
        print('To sample with a Temperature you should use pdf=True, otherwise the result is the same')
    if not sample_all:
        add_strs += "_(only-masked-sampled)"
    if generate is not False:
        add_strs += "_" + generate + "_(context-" + str(num[0]) + ")"
    if phylo_w:
        add_strs += "_phylo-w"

    print('Generate Class')
    Class = IM_MSA_Transformer(iterations=np.array([Iters]),
                               p_mask=pmask,
                               filename=filename,
                               num=num,
                               filepath=filepath)

    print('Compute results from Class')
    Class.iterations = np.array([Iters])
    Class.p_mask = pmask

    if generate is False:
        print('Generating MSA with same size as the original one')
        old_T, new_T = Class.Batch_MSA(simplified=True,
                                    repetitions=depth,
                                    use_pdf=pdf, sample_all=sample_all, T=T, phylo=phylo_w)
        # Number of generated sequences, capped by what old_T actually holds along axis 1
        NNN = min(num[0] * depth, old_T.shape[1])

    else:  # 'linear-ran' or 'linear-tot-ran' (validated above)
        print('Generate MSA with linear context generation')
        orig_tkn = np.load(path + "/" + path1 + "/original-tokens.npy")
        # Context and ancestors come from two independent permutations of the rows;
        # the fixed seed makes the selection reproducible across runs.
        np.random.seed(0)
        indices = np.random.permutation(orig_tkn.shape[0])
        indexes_context = indices[:num[0]]
        indices = np.random.permutation(orig_tkn.shape[0])
        if depth == -1:
            ind_ancestor = indices  # every sequence is used as an ancestor
        elif range_vals is False:
            ind_ancestor = indices[:depth]
        elif range_vals[1] == -1:
            ind_ancestor = indices[range_vals[0]:]
            # Resolve the -1 sentinel so the output filename records the real end index
            range_vals[1] = orig_tkn.shape[0]
        else:
            ind_ancestor = indices[range_vals[0]:range_vals[1]]
        ancestor = orig_tkn[ind_ancestor, :]
        context = orig_tkn[indexes_context, :][None, :, :]
        if generate == 'linear-tot-ran':
            # Presumably tells Context_MSA to resample the context each time (see `generate` help)
            context = 'tot-ran'
        old_T, new_T = Class.Context_MSA(None, ancestor, context, use_pdf=pdf, simplified=True, sample_all=sample_all, print_all=print_all, T=T)
        if generate == 'linear-tot-ran':
            old_T = ancestor[None, :, :]
        NNN = new_T.shape[2]

    # --- Create the run folder, named after the generation settings ---
    path2 = "Generated" + "_iter-" + str(
        Iters) + "_pmask-" + str(pmask) + "_seqs-" + str(NNN) + add_strs
    _make_dir(path + "/" + path1 + "/" + path2)

    # --- Save data ---
    if generate is False or generate == 'linear-tot-ran':
        np.save(path1 + "/" + path2 + "/shuffled-tokens.npy", old_T[0])
    else:
        np.save(path1 + "/" + path2 + "/context-tokens.npy", old_T[0])
    str_add = ''
    if range_vals is not False:
        str_add = '_range_indx_' + str(range_vals[0]) + ',' + str(range_vals[1])
    np.save(path1 + "/" + path2 + "/new-tokens" + str_add + ".npy", new_T[0])

    return 1

# Print the CLI documentation only when run as a script (not on import). When executed as a
# script, @call_parse has already run gen_MSAs at decoration time, so this prints afterwards.
if __name__ == "__main__":
    print(show_doc(gen_MSAs))