CIRIdeep.py 39.2 KB
Newer Older
adaZ-9's avatar
adaZ-9 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Input, Lambda, LeakyReLU, Concatenate, BatchNormalization
from keras.optimizers import Adam
import keras.backend as K
import pandas as pd
import os
import numpy as np
from sklearn import metrics
import random
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import argparse as ap
from keras.utils import np_utils
import math
import config

adaZ-9's avatar
adaZ-9 committed
18
class clf():
    """Fully-connected classifier used by CIRIdeep / CIRIdeep-A.

    A Keras multilayer perceptron with batch-normalised ReLU hidden layers
    and optional dropout. The output head is a single sigmoid unit for the
    binary CIRIdeep model, or a 3-way softmax for CIRIdeep-A
    (classes: no-diff / a-higher / b-higher).
    """

    def __init__(self, n_in=6523, hidden=[1200, 500, 300, 200], drop_out=[0, 0, 0, 0, 0], batch_size=512, kernel_initializer='glorot_normal', learning_rate=0.001, splicing_amount=True, CIRIdeepA=False):
        """Build and compile the network.

        Args:
            n_in: number of input features; increased by 2 when
                `splicing_amount` is True (two splicing-amount columns).
            hidden: units per hidden Dense layer.
            drop_out: dropout rates; drop_out[0] applies to the input,
                drop_out[i+1] after hidden layer i — must have
                len(hidden)+1 entries.
            batch_size: batch size stored for `fit`.
            kernel_initializer: initializer for the Dense kernels.
            learning_rate: Adam learning rate.
            splicing_amount: whether splicing-amount features are included.
            CIRIdeepA: build the 3-class softmax variant instead of binary.
        """
        if splicing_amount:
            n_in += 2
        # named `inputs` (not `input`) to avoid shadowing the builtin
        inputs = Input(shape=(n_in, ), dtype='float32', name='input')
        # thread `h` through the stack; also fixes a NameError when
        # `hidden` is empty and no input dropout is used
        h = inputs
        if drop_out[0] > 0:
            h = Dropout(drop_out[0])(h)
        for i in range(len(hidden)):
            h = Dense(hidden[i], activation=None, kernel_initializer=kernel_initializer)(h)
            h = BatchNormalization()(h)
            h = Activation('relu')(h)
            if drop_out[i+1] > 0:
                h = Dropout(drop_out[i+1])(h)
        if not CIRIdeepA:
            # binary model: sigmoid output + binary cross-entropy
            output = Dense(1, activation='sigmoid', name='output')(h)
            loss = 'binary_crossentropy'
        else:
            # 3-class model: softmax output + categorical cross-entropy
            output = Dense(3, activation='softmax', name='output')(h)
            loss = 'categorical_crossentropy'
        model = Model(inputs=inputs, outputs=output)
        model.compile(optimizer=Adam(lr=learning_rate), loss=loss, metrics=['accuracy'])

        self.model = model
        self.batch_size = batch_size
        self.n_in = n_in
        self.hidden = hidden
        self.drop_out = drop_out
        self.learning_rate = learning_rate
        self.CIRIdeepA = CIRIdeepA

    def fit(self, inputdata):
        """Train on class-balanced mini-batches.

        Returns the metrics dict produced by train_on_balanced_batch.
        """
        best_metrics = train_on_balanced_batch(self.model, inputdata, batch_size=self.batch_size, CIRIdeepA=self.CIRIdeepA)

        return best_metrics

    def predict(self, test_data):
        """Return model predictions for test_data['X']."""
        Y_pred = self.model.predict([test_data['X']])

        return Y_pred
    
def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, geneExp_colnames, odir, RBP_dir, splicing_max='', splicing_dir='', splicing_amount=False, n_epoch=100, CIRIdeepA=False):
    '''
    generate dataset for training while fill test data

    Generator: repeatedly walks the whole train list for config.n_epoch
    epochs, grouping up to 4 sample comparisons per step, and yields one
    training batch dict per group ({'X', 'Y', 'rownames', 'colnames'}).
    During epoch 0 only, the held-out test rows of each group are appended
    to `test_data` (a dict of lists: 'X', 'Y', 'rownames'), so the caller
    accumulates the full test set after the first epoch.

    NOTE(review): the `n_epoch` parameter is immediately overridden by
    config.n_epoch, and `splicing_amount` is passed to
    construct_data_from_label hard-coded as False — confirm both are intended.
    '''
    
    # epoch looping
    
    # test_data_saved = False
    n_epoch = config.n_epoch
    
    for i in range(n_epoch):
        
        sample_list = np.array(list(train_list.index))
        
        # 1 epoch
        while len(sample_list) > 0:
            
            # randomly select 4 sample comparisons
            # NOTE(review): reseeding with the same fixed seed on every
            # iteration makes the grouping deterministic — confirm intended
            random.seed(config.random_seed_eval)
            idx_inbatch = random.sample(range(len(sample_list)), min(4, len(sample_list)))
            sample_inbatch = sample_list[idx_inbatch]
            sample_list = np.delete(sample_list, idx_inbatch)
            print('epoch: %i; sample in batch: %s' %(i, ', '.join(sample_inbatch)))
            
            X_train_list = []
            Y_train_list = []
            rownames_train_list = []
            X_test_list = []
            Y_test_list = []
            rownames_test_list = []
            
            for sample in sample_inbatch:
            
                label_fn = train_list.loc[sample].label_fn
                X_tmp, Y_tmp, rownames_tmp, colnames = \
                    construct_data_from_label(label_fn, sample, seqFeature_df, geneExp_absmax, geneExp_colnames, RBP_dir, splicing_dir=splicing_dir, splicing_max=splicing_max, splicing_amount=False, phase='train', CIRIdeepA=CIRIdeepA)
                inputdata = {'X':X_tmp, 'Y':Y_tmp, 'rownames':rownames_tmp}
                # hold out a per-sample test split before training
                X_train_tmp, Y_train_tmp, X_test_tmp, Y_test_tmp, rownames_train_tmp, rownames_test_tmp = \
                    split_testdata_out(inputdata, CIRIdeepA=CIRIdeepA)
                    
                X_train_list.append(X_train_tmp)
                Y_train_list.append(Y_train_tmp)
                rownames_train_list.append(rownames_train_tmp)
                X_test_list.append(X_test_tmp)
                Y_test_list.append(Y_test_tmp)
                rownames_test_list.append(rownames_test_tmp)

            # concatenate 4 data
            X_train = np.vstack(X_train_list)
            if not CIRIdeepA:
                # binary labels are 1-D; CIRIdeep-A labels are one-hot 2-D
                Y_train = np.hstack(Y_train_list)
                Y_test = np.hstack(Y_test_list)
            else:
                Y_train = np.vstack(Y_train_list)
                Y_test = np.vstack(Y_test_list)
            rownames_train = np.hstack(rownames_train_list)
            X_test = np.vstack(X_test_list)
            rownames_test = np.hstack(rownames_test_list)

            # add test dataset into test_data in epoch 0
            if i == 0:
                test_data['X'].append(X_test)
                test_data['Y'].append(Y_test)
                test_data['rownames'].append(rownames_test)
            
            # generate train dataset for 4 samples
            train_data = {'X':X_train, 'Y':Y_train, 'rownames':rownames_train, 'colnames':colnames}
            yield train_data

    return
            

def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12345, CIRIdeepA=False):
    '''
    Hold out a fixed test split from one sample comparison's data.

    The `test_prop` / `random_seed` arguments are overridden by the values
    in `config`. For the binary model the held-out split is class-balanced
    (half positive, half negative); for CIRIdeep-A it is a plain random
    split. Returns
    (X_train, Y_train, X_test, Y_test, rownames_train, rownames_test).
    '''
    test_prop = config.test_prop
    random_seed = config.random_seed_test

    X = inputdata['X']
    Y = inputdata['Y']
    rownames = inputdata['rownames']

    if CIRIdeepA:
        # plain random split for the 3-class model
        n_test = int(len(rownames) * test_prop)
        np.random.seed(random_seed)
        order = list(range(len(rownames)))
        for _ in range(5):
            np.random.shuffle(order)
        idx_test = order[:n_test]
        idx_train = order[n_test:]
        X_test, X_train = X[idx_test, ], X[idx_train, ]
        Y_test, Y_train = Y[idx_test, ], Y[idx_train, ]
        rownames_test = rownames[idx_test]
        rownames_train = rownames[idx_train]
    else:
        # equal counts of positives and negatives in the held-out set
        n_test = int(len(rownames) * test_prop * 0.5)
        pos_idx = np.where(Y == 1)[0]
        neg_idx = np.where(Y == 0)[0]

        # shuffle each class with the same fixed seed
        np.random.seed(random_seed)
        np.random.shuffle(pos_idx)
        np.random.seed(random_seed)
        np.random.shuffle(neg_idx)

        pos_test, pos_train = pos_idx[:n_test], pos_idx[n_test:]
        neg_test, neg_train = neg_idx[:n_test], neg_idx[n_test:]

        X_test = np.vstack((X[pos_test, ], X[neg_test, ]))
        X_train = np.vstack((X[pos_train, ], X[neg_train, ]))
        Y_test = np.asarray([1.]*len(pos_test) + [0.]*len(neg_test))
        Y_train = np.asarray([1.]*len(pos_train) + [0.]*len(neg_train))
        rownames_test = np.hstack((rownames[pos_test], rownames[neg_test]))
        rownames_train = np.hstack((rownames[pos_train], rownames[neg_train]))

    return X_train, Y_train, X_test, Y_test, rownames_train, rownames_test


def split_eval_train_data(inputdata, n_fold=5): #######
    '''
    Split training data into a train part and a validation part; the
    validation part is 1/n_fold of the rows, chosen at random with the
    seed from config.random_seed_eval.
    '''
    X = inputdata['X']
    Y = inputdata['Y']
    rownames = inputdata['rownames']

    # shuffle all row indices with the configured seed
    n_val = int(len(rownames) / n_fold)
    order = np.asarray(range(0, len(rownames)))
    np.random.seed(config.random_seed_eval)
    np.random.shuffle(order)

    idx_val = order[:n_val]
    idx_train = order[n_val:]

    return (X[idx_train, ], X[idx_val, ],
            Y[idx_train, ], Y[idx_val, ],
            rownames[idx_train], rownames[idx_val])
    

def split_data_into_balanced_minibatch(Y_train, batch_size=64, pos_prop=0.5, CIRIdeepA=False):
    '''
    Partition training row indices into mini-batches with a fixed
    positive/negative ratio (pos_prop is taken from config). The smaller
    class is recycled (tiled) so every batch is full. Returns a list of
    shuffled index arrays, one per batch.
    '''
    pos_prop = config.pos_prop

    if not CIRIdeepA:
        pos_idx = np.where(Y_train == 1)[0] # shuffle
        neg_idx = np.where(Y_train == 0)[0]
    else:
        # one-hot labels: class 0 ("no diff") counts as negative
        neg_idx = np.where(Y_train[:,0] == 1)[0]
        pos_idx = np.where(Y_train[:,0] == 0)[0]

    # shrink the batch when there is not enough data for one full batch
    if len(pos_idx) + len(neg_idx) < batch_size:
        batch_size = 2 * min(len(pos_idx), len(neg_idx))

    # half pos half neg in batch (by default)
    pos_size = int(batch_size * pos_prop)
    neg_size = batch_size - pos_size

    # enough batches to cycle through the larger class once
    n_batch = max(len(pos_idx) // pos_size, len(neg_idx) // neg_size)

    for _ in range(5):
        np.random.shuffle(pos_idx)
    pos_idx = np.tile(pos_idx, n_batch)

    for _ in range(5):
        np.random.shuffle(neg_idx)
    neg_idx = np.tile(neg_idx, n_batch)

    batches = []
    for b in range(0, n_batch):
        batch = np.hstack((pos_idx[b*pos_size:(b+1)*pos_size],
                           neg_idx[b*neg_size:(b+1)*neg_size]))
        np.random.shuffle(batch)
        batches.append(batch)

    return batches

###################################
def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10, CIRIdeepA=False):
    '''
    Train `model` on class-balanced mini-batches and evaluate once on the
    validation split after the final batch (for each sample-pair group).

    Args:
        model: compiled Keras model.
        inputdata: dict with 'X_train', 'Y_train', 'X_val', 'Y_val',
            'rownames_train', 'rownames_val'.
        batch_size: requested mini-batch size (may shrink for small data).
        validation_freq: unused; kept for interface compatibility.
        CIRIdeepA: whether the 3-class model is being trained.

    Returns:
        dict of metrics — AUROC/AUPR plus loss for the binary model, or
        loss/accuracy plus macro/weighted precision, recall and F1 for
        CIRIdeep-A.
    '''

    # hoist the invariant unpack out of the training loop (was re-done
    # every batch); dead `best_roc` / `best_loss` trackers removed
    X_train = inputdata['X_train']
    Y_train = inputdata['Y_train']
    X_val = inputdata['X_val']
    Y_val = inputdata['Y_val']

    # construct balanced minibatch
    idx_inbatch_list = split_data_into_balanced_minibatch(Y_train, batch_size=batch_size, CIRIdeepA=CIRIdeepA)
    n_batch = len(idx_inbatch_list)

    # train on batch for n times
    for i in range(0, n_batch):

        loss_on_batch = model.train_on_batch(X_train[idx_inbatch_list[i], ], Y_train[idx_inbatch_list[i], ]) # debug: loss unstable
        print('loss on batch%i: %.3f' % (i, loss_on_batch[0]))
        print('accuracy on batch%i: %.3f' % (i, loss_on_batch[1]))

        if i == n_batch - 1:
            # evaluate only after the last batch
            Y_predict = model.predict(X_val)

            eval_val = model.evaluate(X_val, Y_val, verbose=1)
            eval_train = model.evaluate(X_train, Y_train, verbose=1)

            if not CIRIdeepA:
                roc = metrics.roc_auc_score(Y_val, Y_predict)
                pr = metrics.average_precision_score(Y_val, Y_predict)
            else:
                # convert softmax output to one-hot via its row-wise max
                Y_predict_categorical = (Y_predict == Y_predict.max(axis=1, keepdims=True)).astype('float')
                precision_score_macro = metrics.precision_score(Y_val, Y_predict_categorical, average="macro")
                precision_score_weighted = metrics.precision_score(Y_val, Y_predict_categorical, average="weighted")
                recall_macro = metrics.recall_score(Y_val, Y_predict_categorical, average='macro')
                recall_weighted = metrics.recall_score(Y_val, Y_predict_categorical, average="weighted")
                F1_macro = metrics.f1_score(Y_val, Y_predict_categorical, average='macro')
                F1_weighted = metrics.f1_score(Y_val, Y_predict_categorical, average='weighted')

    print('*' * 50)
    if not CIRIdeepA:
        print('batch %i: loss_val - %.3f; accuracy_val - %.3f; loss_train - %.3f; accuracy_train - %.3f; auroc_val - %.3f; aupr_val - %.3f' % (i, eval_val[0], eval_val[1], eval_train[0], eval_train[1], roc, pr))
    else:
        print('batch %i: loss_val - %.3f; accuracy_val - %.3f; loss_train - %.3f; accuracy_train - %.3f' % (i, eval_val[0], eval_val[1], eval_train[0], eval_train[1]))
    print('*' * 50)

    if not CIRIdeepA:
        return {'roc':roc, 'loss_eval':eval_val, 'loss_train':eval_train, 'aupr':pr}
    else:
        return {'loss_train':eval_train[0], 'acc_train':eval_train[1], 'loss_val':eval_val[0], 'acc_val':eval_val[1], 
        'precision_score_macro':precision_score_macro, 'precision_score_weighted':precision_score_weighted, 
        'recall_macro':recall_macro, 'recall_weighted':recall_weighted, 
        'F1_macro':F1_macro, 'F1_weighted':F1_weighted}

###################################
def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False):
    '''
    Parse one sample-comparison label file (tab-separated with a header).

    Rows whose read coverage (I+S) is at or below `min_read_cov` in either
    sample are discarded. For CIRIdeep returns (pos, neg); for CIRIdeep-A
    returns (no_diff, a_higher, b_higher). Each dict maps circRNA id to
    its posterior probability.
    '''
    if not CIRIdeepA:

        pos = {}
        neg = {}

        with open(label_fn, 'r') as f:

            header = None

            for line in f:

                ele = line.strip().split('\t')

                if header is None:
                    # header row: map column name -> column index
                    header = {name: idx for idx, name in enumerate(ele)}
                    continue

                post_pr = float(ele[header['post_pr']])
                eid = ele[header['ID']]
                I1 = int(ele[header['I1']])
                I2 = int(ele[header['I2']])
                S1 = int(ele[header['S1']])
                S2 = int(ele[header['S2']])
                delta = float(ele[header['delta.mle']])

                # skip low-coverage circRNAs
                if I1 + S1 <= min_read_cov or I2 + S2 <= min_read_cov:
                    continue
                if post_pr > 1 - significance and min(S1, S2) > 2:
                    pos[eid] = post_pr
                elif post_pr < significance and min(I1, S1) > 2 and min(I2, S2) > 2:
                    neg[eid] = post_pr

        return pos, neg

    else:
        no_diff = {}
        a_higher = {}
        b_higher = {}

        with open(label_fn, 'r') as f:

            header = None

            for line in f:

                ele = line.strip().split('\t')

                if header is None:
                    header = {name: idx for idx, name in enumerate(ele)}
                    continue

                post_pr = float(ele[header['post_pr']])
                eid = ele[header['ID']]
                I1 = int(ele[header['I1']])
                I2 = int(ele[header['I2']])
                S1 = int(ele[header['S1']])
                S2 = int(ele[header['S2']])

                #***direction***#
                # the sign of delta.mle encodes which sample is higher
                delta = float(ele[header['delta.mle']])

                if I1 + S1 <= min_read_cov or I2 + S2 <= min_read_cov:
                    continue

                if post_pr > 1 - significance and min(S1, S2) > 2 and delta < 0:
                    a_higher[eid] = post_pr

                if post_pr > 1 - significance and min(S1, S2) > 2 and delta > 0:
                    b_higher[eid] = post_pr

                elif post_pr < significance and min(I1, S1) > 2 and min(I2, S2) > 2:
                    no_diff[eid] = post_pr

        return no_diff, a_higher, b_higher

def read_label_fn_predict(label_fn):
    """Read a prediction label file: one circRNA id per line."""
    with open(label_fn, 'r') as f:
        content = f.read()
    return content.strip().split('\n')

def read_train_list(fn):
    '''
    Read the training sample list (tab-separated; the first column is the
    sample ID index, remaining columns include label_fn etc.).
    Returns a DataFrame.
    '''
    # columns: ID label_fn num_pos num_neg geneExp_fn
    return pd.read_csv(fn, sep='\t', index_col=0)

def read_sequence_feature(seqFeature_fn):
    """Load the cis sequence-feature matrix (tab-separated, indexed by
    circRNA id); the first 13 columns after the index are dropped."""
    table = pd.read_csv(seqFeature_fn, sep='\t', index_col=0)
    return table.iloc[:, 13:]
    
def read_geneExp_absmax(fn):
    """Read the per-gene absolute-maximum expression table used for
    normalisation. Returns (values array, list of gene names)."""
    series = pd.read_csv(fn, index_col=0, sep='\t')['max']
    return series.values, list(series.index)

def read_geneExp(sample, geneExp_absmax, nrow=1, phase='train', RBPexp_dir=''):
    """Load RBP expression for one sample, normalise it by the global
    per-gene maximum, and tile the resulting vector to `nrow` rows.

    In predict phase, normalised values above 1 are clipped to 1 (the
    sample may exceed the maximum seen in training).
    """
    geneExp_fn = os.path.join(RBPexp_dir, sample + '_rpb.csv')
    expr = pd.read_csv(geneExp_fn, index_col=0, sep='\t').transpose()
    normalized = expr.values.flatten() / geneExp_absmax
    if phase == 'predict':
        normalized[normalized > 1] = 1
    return np.tile(normalized, (nrow, 1))

def read_splicing_amount(sample1, sample2, splicing_max, eid_list, phase='train', splicing_dir=''):
    """Load linear-splicing amounts for both samples, normalised by the
    per-circRNA maximum.

    Returns two (n, 1) arrays (one per sample, rows ordered by `eid_list`)
    and the single-entry column-name list. In predict phase, normalised
    values above 1 are clipped to 1.
    """
    max_amount = splicing_max.loc[eid_list]['max_splicing_amount']

    def _load_one(sample):
        # one '<sample>.output' file per sample, indexed by circRNA id
        fn = os.path.join(splicing_dir, sample + '.output')
        amounts = pd.read_csv(fn, index_col=0, sep='\t').loc[eid_list, 'splicing_amount']
        amounts = amounts / max_amount
        if phase == 'predict':
            amounts[amounts > 1] = 1
        arr = np.asarray(amounts)
        return arr.reshape(arr.shape[0], 1)

    return _load_one(sample1), _load_one(sample2), ['splicing_amount']

def construct_data_from_label(label_fn, sample, seqFeature_df, geneExp_absmax, geneExp_colnames, RBPexp_dir, splicing_max='', splicing_dir='', splicing_amount=False, sep='_', CIRIdeepA=False, phase='train'):
    """Assemble the feature matrix for one sample comparison.

    Joins cis sequence features with per-sample normalised RBP expression
    (and, for the binary model, normalised splicing amounts). `sample` is
    the comparison id '<sampleA><sep><sampleB>'.

    Returns:
        (X, Y, rownames, colnames) in train phase;
        (X, rownames, colnames) in predict phase.
    """
    if CIRIdeepA:
        if phase == 'train':
            no_diff, a_higher, b_higher = read_label_fn(label_fn, CIRIdeepA=CIRIdeepA)
            #***direction***#
            # keep only circRNAs that also have sequence features
            no_diff_eid = list(seqFeature_df.index.intersection(no_diff.keys()))
            a_higher_eid = list(seqFeature_df.index.intersection(a_higher.keys()))
            b_higher_eid = list(seqFeature_df.index.intersection(b_higher.keys()))
            # each circRNA appears twice: once per comparison direction
            # (a-vs-b and b-vs-a) to augment the training data symmetrically
            eid_list = no_diff_eid + a_higher_eid + b_higher_eid + no_diff_eid + a_higher_eid + b_higher_eid
        elif phase == 'predict':
            eid_list = list(seqFeature_df.index.intersection(read_label_fn_predict(label_fn)))
            
        seqFeature_df = seqFeature_df.loc[eid_list]
        
        
        # trans-feature
        # NOTE(review): this branch splits on a literal '_' instead of the
        # `sep` parameter — confirm intended
        sample_a = sample.split('_')[0]
        sample_b = sample.split('_')[1]
    
        #***direction***#
        if phase == 'train':
            # half of eid_list corresponds to each direction, so each
            # expression block covers len(eid_list)/2 rows
            geneExp_a = read_geneExp(sample_a, geneExp_absmax, nrow=int(len(eid_list)/2), RBPexp_dir=RBPexp_dir)
            geneExp_b = read_geneExp(sample_b, geneExp_absmax, nrow=int(len(eid_list)/2), RBPexp_dir=RBPexp_dir)
            geneExp_left = np.vstack((geneExp_a, geneExp_b))
            geneExp_right = np.vstack((geneExp_b, geneExp_a))
        elif phase == 'predict':
            geneExp_a = read_geneExp(sample_a, geneExp_absmax, nrow=len(eid_list), phase='predict', RBPexp_dir=RBPexp_dir)
            geneExp_b = read_geneExp(sample_b, geneExp_absmax, nrow=len(eid_list), phase='predict', RBPexp_dir=RBPexp_dir)
            
        if phase == 'train':
            X = np.hstack((seqFeature_df, geneExp_left, geneExp_right))
            # labels 1 and 2 swap for the reversed direction
            Y = np_utils.to_categorical(
                np.asarray([0.]*len(no_diff_eid) + [1.]*len(a_higher_eid) + [2.]*len(b_higher_eid) + 
                [0.]*len(no_diff_eid) + [2.]*len(a_higher_eid) + [1.]*len(b_higher_eid)))
            rownames = np.asarray([sample + '-' + x for x in eid_list])
            # NOTE(review): geneExp_simple_colnames is unused here, and
            # colnames lists the RBP columns only once although X contains
            # two expression blocks — verify downstream use
            geneExp_simple_colnames = [gene for gene in geneExp_colnames]
            colnames = np.hstack([seqFeature_df.columns.values, geneExp_colnames])
        elif phase == 'predict':
            X = np.hstack((seqFeature_df, geneExp_a, geneExp_b))
            rownames = np.asarray([sample + '-' + x for x in eid_list])
            geneExp_simple_colnames = [gene for gene in geneExp_colnames]
            colnames = np.hstack([seqFeature_df.columns.values, geneExp_colnames])
        
        if phase == 'train':
            return X, Y, rownames, colnames
        elif phase == 'predict':
            return X, rownames, colnames
    
    else:
        if phase == 'train':
            pos, neg = read_label_fn(label_fn, CIRIdeepA=CIRIdeepA)
            pos_eid = list(seqFeature_df.index.intersection(pos.keys()))
            neg_eid = list(seqFeature_df.index.intersection(neg.keys()))
            eid_list = pos_eid + neg_eid
        elif phase == 'predict':
            eid_list = list(seqFeature_df.index.intersection(read_label_fn_predict(label_fn)))
        
        # cis-feature
        seqFeature_df = seqFeature_df.loc[eid_list]
        
        # trans-feature
        sample_a = sample.split(sep)[0]
        sample_b = sample.split(sep)[1]
        geneExp_a = read_geneExp(sample_a, geneExp_absmax, nrow=len(eid_list), phase=phase, RBPexp_dir=RBPexp_dir)
        geneExp_b = read_geneExp(sample_b, geneExp_absmax, nrow=len(eid_list), phase=phase, RBPexp_dir=RBPexp_dir)
        
        # splicing amount 
        splicing_amount_a, splicing_amount_b, splicing_amount_colnames = read_splicing_amount(sample_a, sample_b, splicing_max, eid_list, phase=phase, splicing_dir=splicing_dir)
        
        X = np.hstack((seqFeature_df, geneExp_a, geneExp_b, splicing_amount_a, splicing_amount_b))
        if phase == 'train':
            Y = np.asarray([1.]*len(pos_eid) + [0.]*len(neg_eid))
        
        rownames = np.asarray([sample + '-' + x for x in eid_list])
        geneExp_simple_colnames = [sample + '-' + gene for sample in [sample_a, sample_b] for gene in geneExp_colnames]
        colnames = np.hstack([seqFeature_df.columns.values, geneExp_simple_colnames, ['splicing_a', 'splicing_b']])
        
        if phase == 'train':
            return X, Y, rownames, colnames
        elif phase == 'predict':
            return X, rownames, colnames


def write_current_pred(test_data, y_pred, fn, CIRIdeepA, phase='train'):
    """Write predictions to `fn`, one tab-separated row per circRNA.

    In train phase the true labels from test_data['Y'] are written before
    the predicted values; in predict phase only predictions are written.
    CIRIdeep-A rows carry three class probabilities, binary rows one.
    """
    if phase not in ('train', 'predict'):
        # unknown phase: nothing to write (matches original no-op)
        return

    with open(fn, 'w') as fo:
        for i, rowname in enumerate(test_data['rownames']):
            if phase == 'predict':
                if CIRIdeepA:
                    fields = [rowname, str(y_pred[i, 0]), str(y_pred[i, 1]), str(y_pred[i, 2])]
                else:
                    fields = [rowname, str(y_pred[i, 0])]
            else:
                if CIRIdeepA:
                    fields = [rowname,
                              str(test_data['Y'][i, 0]), str(test_data['Y'][i, 1]), str(test_data['Y'][i, 2]),
                              str(y_pred[i, 0]), str(y_pred[i, 1]), str(y_pred[i, 2])]
                else:
                    fields = [rowname, str(test_data['Y'][i]), str(y_pred[i, 0])]
            fo.write('\t'.join(fields) + '\n')

    return

def plot_auroc(roc_val, roc_test, loss_training, loss_eval, loss_test, test_freq, odir):
    """Save two diagnostic figures to `odir`: validation/test AUROC per
    step, and training/validation/test loss per step. Test-set curves are
    plotted every `test_freq` steps."""

    def _new_figure(title):
        # common figure setup for both plots
        plt.figure(figsize=(6, 6))
        plt.title(title)

    # AUROC curves
    _new_figure('Evaluation/Test ROC')
    plt.plot(range(1, len(roc_val) + 1), roc_val)
    plt.plot(range(test_freq, len(roc_val) + 1, test_freq), roc_test)
    plt.ylabel('Auroc')
    plt.xlabel('Step')
    plt.savefig(os.path.join(odir, 'evaluation test roc.png'))

    # loss curves
    _new_figure('Training/Evaluation/Test loss')
    plt.plot(range(1, len(loss_training) + 1), loss_training)
    plt.plot(range(1, len(loss_eval) + 1), loss_eval)
    plt.plot(range(test_freq, len(loss_eval) + 1, test_freq), loss_test)
    plt.ylabel('Loss')
    plt.xlabel('Step')
    plt.savefig(os.path.join(odir, 'loss_training_eval_test.png'))
    return

def read_splicing_amount_max(fn):
    """Read the per-circRNA maximum-splicing-amount table
    (tab-separated, indexed by circRNA id)."""
    return pd.read_csv(fn, sep='\t', index_col=0)

def eval_continuous_label_auc(y_true, y_pred, outputdir, prefix):
    """Compute the test ROC curve, save the raw y_true/y_pred arrays as
    .npy files in `outputdir`, and save an AUROC plot named
    '<prefix>.png'."""
    auroc = metrics.roc_auc_score(y_true, y_pred)
    fpr, tpr, threshold = metrics.roc_curve(y_true, y_pred)

    # persist the raw arrays alongside the plot
    np.save(os.path.join(outputdir, 'y_true.npy'), y_true)
    np.save(os.path.join(outputdir, 'y_pred.npy'), y_pred)

    roc_auc = metrics.auc(fpr, tpr)

    plt.figure(figsize=(6, 6))
    plt.title('Test AUROC')
    plt.plot(fpr, tpr, 'b', label = 'Test AUC = %0.3f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig(os.path.join(outputdir, prefix+'.png'))
    return

def read_roc(odir):
    """Parse 'validation_test_auroc.txt' from `odir`.

    Returns two lists of strings: validation AUROC for every step, and
    test AUROC for steps where the second column is not 'none'. Lines
    starting with 'v' (the header) are skipped.
    """
    roc_val = []
    roc_test = []
    with open(os.path.join(odir, 'validation_test_auroc.txt'), 'r') as f:
        for line in f:
            if line.startswith('v'):
                continue
            parts = line.strip().split('\t')
            roc_val.append(parts[0])
            if parts[1] != 'none':
                roc_test.append(parts[1])
    return roc_val, roc_test

def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', splicing_dir='', splicing_amount=True, CIRIdeepA=False):
    """Train CIRI-deep (binary) or CIRI-deep(A) (categorical) end to end.

    Streams training batches from `generate_data_batch`, fits the model one
    batch at a time, evaluates on an accumulated held-out test set every
    `test_freq` steps with checkpointing / early stopping, and finally saves
    the model and an ROC evaluation into `odir`.

    Parameters
    ----------
    fn : str
        Path to the training list (parsed by read_train_list).
    odir : str
        Output directory (created if missing) for logs, checkpoints,
        per-step predictions and plots.
    geneExp_absmax : str
        Path to the gene-expression abs-max normalisation table.
    seqFeature : str
        Path to the per-circRNA cis (sequence) feature table.
    RBP_dir : str
        Directory of per-sample RBP expression files.
    splicing_max, splicing_dir : str
        Splicing-amount max table / directory; only used when
        `splicing_amount` is True.
    splicing_amount : bool
        Whether splicing-amount features are part of the input.
    CIRIdeepA : bool
        Select the CIRI-deep(A) variant (different metrics/log format).
    """
    if not os.path.isdir(odir):
        os.mkdir(odir)

    print('output directory: %s' % odir)
    print('input train list: %s' % fn)
    train_list = read_train_list(fn)

    m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
    # Held-out test examples are appended here (as lists of arrays) by the
    # batch generator and concatenated once at the first test step.
    test_data_lst = {'X':[], 'Y':[], 'rownames':[]}

    print('Loading features...')
    seqFeature_df = read_sequence_feature(seqFeature)
    geneExp_absmax, geneExp_colnames = read_geneExp_absmax(geneExp_absmax)

    if splicing_amount:
        splicing_max = read_splicing_amount_max(splicing_max)
    else:
        splicing_max = []

    print('Training start...')
    data_batch_generator = generate_data_batch(
        train_list, test_data_lst, seqFeature_df, geneExp_absmax, geneExp_colnames, odir, RBP_dir, splicing_dir=splicing_dir, splicing_max=splicing_max, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)

    # bookkeeping (several never-read counters from the original were removed)
    roc_val = []
    roc_test = []
    loss_train_lst = []
    acc_train_lst = []
    loss_eval_lst = []
    acc_eval_lst = []
    loss_test_lst = []
    acc_test_lst = []
    best_loss = float('Inf')
    patience_limit = config.patience_limit
    patience = 0
    step = 0
    # evaluate on the test set roughly four times per pass over the list
    test_freq = math.ceil(len(train_list.index)/4)
    print('step', 'auroc_eval', 'aupr_eval', 'loss_eval', 'acc_eval', 'loss_train', 'acc_train', sep='\t', file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+'))
    print('step', 'patience', 'auroc_test', 'aupr_test', 'loss_test', 'acc_test', sep='\t', file=open(odir + '/roc_pr_loss_test.log', 'a+'))

    # training loop: one generator item == one optimization step
    for data_batch in data_batch_generator:

        step += 1

        X = data_batch['X']
        Y = data_batch['Y']
        rownames = data_batch['rownames']

        # split the batch into train / validation folds
        inputdata = {'X':X, 'Y':Y, 'rownames':rownames}
        X_train, X_val, Y_train, Y_val, rownames_train, rownames_val = split_eval_train_data(inputdata, n_fold=config.n_fold)

        inputdata = {'X_train':X_train, 'X_val':X_val, 'Y_train':Y_train, 'Y_val':Y_val, 'rownames_train':rownames_train, 'rownames_val':rownames_val}
        print('step %i optimization start.' % step)
        best_metrics = m.fit(inputdata)

        # log per-step train/validation metrics (binary vs categorical
        # variants return differently shaped metric dicts)
        if not CIRIdeepA:
            loss_train_lst.append(best_metrics['loss_train'][0])
            acc_train_lst.append(best_metrics['loss_train'][1])
            loss_eval_lst.append(best_metrics['loss_eval'][0])
            acc_eval_lst.append(best_metrics['loss_eval'][1])
            print(step, best_metrics['roc'], best_metrics['aupr'], best_metrics['loss_eval'][0], best_metrics['loss_eval'][1], best_metrics['loss_train'][0], best_metrics['loss_train'][1], sep='\t', file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+'))
            roc_val.append(best_metrics['roc'])
        else:
            loss_train_lst.append(best_metrics['loss_train'])
            acc_train_lst.append(best_metrics['acc_train'])
            loss_eval_lst.append(best_metrics['loss_val'])
            acc_eval_lst.append(best_metrics['acc_val'])
            print('%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f' %
                (step, best_metrics['precision_score_macro'], best_metrics['precision_score_weighted'],
                best_metrics['recall_macro'], best_metrics['recall_weighted'],
                best_metrics['F1_macro'], best_metrics['F1_weighted'],
                best_metrics['loss_val'], best_metrics['acc_val'],
                best_metrics['loss_train'], best_metrics['acc_train']),
                file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+'))

        # periodic evaluation on the held-out test set
        if step % test_freq == 0:

            # first test step: materialize the accumulated lists of arrays
            if step <= test_freq:
                test_data_lst['X'] = np.concatenate(test_data_lst['X'])
                test_data_lst['Y'] = np.concatenate(test_data_lst['Y'])
                test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames'])

            loss_test = m.model.evaluate(test_data_lst['X'], test_data_lst['Y'], verbose=1)
            Y_pred = m.predict(test_data_lst)

            if not CIRIdeepA:
                auroc_test = metrics.roc_auc_score(test_data_lst['Y'], Y_pred)
                aupr_test = metrics.average_precision_score(test_data_lst['Y'], Y_pred)
                roc_test.append(auroc_test)
                loss_test_lst.append(loss_test[0])
                acc_test_lst.append(loss_test[1])

                # write predict result
                write_current_pred(test_data_lst, Y_pred, os.path.join(odir, 'step_{0}_pred.txt'.format(step)), CIRIdeepA, phase='train')

                # plot auroc
                plot_auroc(roc_val, roc_test, loss_train_lst, loss_eval_lst, loss_test_lst, test_freq, odir)

            else:
                # one-hot the argmax class per row before sklearn metrics
                Y_predict_categorical = (Y_pred == Y_pred.max(axis=1, keepdims=True)).astype('float')
                precision_score_macro = metrics.precision_score(test_data_lst['Y'], Y_predict_categorical, average="macro")
                precision_score_weighted = metrics.precision_score(test_data_lst['Y'], Y_predict_categorical, average="weighted")
                recall_macro = metrics.recall_score(test_data_lst['Y'], Y_predict_categorical, average='macro')
                recall_weighted = metrics.recall_score(test_data_lst['Y'], Y_predict_categorical, average="weighted")
                F1_macro = metrics.f1_score(test_data_lst['Y'], Y_predict_categorical, average='macro')
                F1_weighted = metrics.f1_score(test_data_lst['Y'], Y_predict_categorical, average='weighted')
                loss_test_lst.append(loss_test[0])
                acc_test_lst.append(loss_test[1])

                write_current_pred(test_data_lst, Y_pred, os.path.join(odir, 'step_{0}_pred.txt'.format(step)), CIRIdeepA, phase='train')

                # plot accuracy in place of AUROC for the categorical model
                plot_auroc(acc_eval_lst, acc_test_lst, loss_train_lst, loss_eval_lst, loss_test_lst, test_freq, odir)

            # checkpoint + early stopping on test loss
            if loss_test[0] < best_loss:
                best_loss = loss_test[0]
                m.model.save(os.path.join(odir, str(step)+'_model_weight.h5'))
                patience = 0
            else:
                patience += 1
                m.model.save(os.path.join(odir, str(step)+'_model_weight.h5'))
                if patience > patience_limit:
                    print('patience > patience_limit. Early Stop.')
                    break

            # rebuild the model and reload the just-saved weights —
            # presumably to cap Keras/TF graph and memory growth across
            # many evaluate/fit cycles; TODO confirm this is still needed
            del m
            K.clear_session()
            m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
            m.model.load_weights(os.path.join(odir, str(step)+'_model_weight.h5'))

            # print test result
            print('*' * 50)
            print('test result:')
            if not CIRIdeepA:
                print('step %i: auroc-%.3f; aupr-%.3f; best_loss=%.3f    patience-%i' % (step, auroc_test, aupr_test, best_loss, patience))
                print('*' * 50)
                print(step, patience, auroc_test, aupr_test, loss_test[0], loss_test[1], sep='\t', file=open(odir + '/roc_pr_loss_test.log', 'a+'))
            else:
                print('step %i: acc-%.3f; loss=%.3f; best loss=%.3f; patience-%i' % (step, loss_test[1], loss_test[0], best_loss, patience))
                print('*' * 50)
                print('%d\t%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f' %
                    (step, patience, precision_score_macro, precision_score_weighted,
                    recall_macro, recall_weighted,
                    F1_macro, F1_weighted,
                    loss_test[0], loss_test[1]),
                    file=open(odir + '/roc_pr_loss_test.log', 'a+'))

    # final save + evaluation
    m.model.save(os.path.join(odir, 'final_model_weight.h5'))
    # Guard: if training ended before the first test step, the test arrays
    # were never concatenated above.
    if isinstance(test_data_lst['X'], list):
        test_data_lst['X'] = np.concatenate(test_data_lst['X'])
        test_data_lst['Y'] = np.concatenate(test_data_lst['Y'])
        test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames'])
    # BUGFIX: the final evaluation previously used `test_data`, an empty
    # placeholder dict that was never populated (the generator fills
    # test_data_lst), so roc_auc_score crashed on empty arrays.
    Y_pred = m.predict(test_data_lst)
    auroc_test = metrics.roc_auc_score(test_data_lst['Y'], Y_pred)
    print('auroc after 10 epoch:', auroc_test)

    eval_continuous_label_auc(test_data_lst['Y'], Y_pred, odir, 'auroc')

    return


def main():
    """CLI entry point: parse arguments and dispatch to train or predict.

    Subcommands
    -----------
    train   : train CIRI-deep / CIRI-deep(A) from a training list.
    predict : score circRNAs with a saved model, one output file per sample.
    """
    prog = 'CIRI-deep'
    description = 'CIRI-deep predicts differentially spliced circRNAs between biological samples. This program can be used for training and predicting.'
    epilog = 'Zihan Zhou <zhouzihan2018m@big.ac.cn>'
    argparser = ap.ArgumentParser(prog=prog, description=description, epilog=epilog)
    subparser = argparser.add_subparsers(dest='subcommand')

    train = subparser.add_parser('train', help='Train CIRI-deep or CIRI-deep(A) using customized data.')
    predict = subparser.add_parser('predict', help='Predict differentially spliced circRNA using CIRI-deep.')

    predict.add_argument('-predict_list', help='file to predict.', dest='predict_list', required=True)
    predict.add_argument('-model_path', help='model path.', dest='model_path', required=True)
    predict.add_argument('-outdir', help='output directory.', dest='outdir', required=True)
    predict.add_argument('-geneExp_absmax', help='gene expression max value.', dest='geneExp_absmax', required=True)
    predict.add_argument('-seqFeature', help='cis feature for all circular RNA.', dest='seqFeature', required=True)
    predict.add_argument('-RBP_dir', help='RBP directory.', dest='RBP_dir', required=True)
    predict.add_argument('-splicing_max', help='splicing amount max value.', default='', dest='splicing_max')
    predict.add_argument('-splicing_dir', help='splicing_amount directory.', default='', dest='splicing_dir')
    predict.add_argument('-sep', help='separate two samples.', default='_', dest='sep')
    predict.add_argument('--CIRIdeepA', action='store_true', dest='CIRIdeepA')

    # BUGFIX: help text previously said 'file to predict.' (copy-paste)
    train.add_argument('-train_list', help='file to train.', dest='train_list', required=True)
    train.add_argument('-outdir', help='output directory.', dest='outdir', required=True)
    train.add_argument('-geneExp_absmax', help='gene expression max value.', dest='geneExp_absmax', required=True)
    train.add_argument('-seqFeature', help='cis feature for all circular RNA.', dest='seqFeature', required=True)
    train.add_argument('-splicing_max', help='splicing amount max value.', dest='splicing_max', default='')
    train.add_argument('-RBP_dir', help='RBP directory.', dest='RBP_dir', required=True)
    train.add_argument('-splicing_dir', help='splicing_amount directory.', dest='splicing_dir', default='')
    train.add_argument('--CIRIdeepA', action='store_true', dest='CIRIdeepA')

    args = argparser.parse_args()
    subcommand = args.subcommand

    if subcommand == 'train':
        train_list = args.train_list
        outdir = args.outdir
        geneExp_absmax = args.geneExp_absmax
        seqFeature = args.seqFeature
        splicing_max = args.splicing_max
        splicing_dir = args.splicing_dir
        RBP_dir = args.RBP_dir
        CIRIdeepA = args.CIRIdeepA
        # CIRI-deep(A) is trained without splicing-amount features
        splicing_amount = not CIRIdeepA

        train_model(train_list, outdir, geneExp_absmax, seqFeature, RBP_dir, splicing_max=splicing_max, splicing_dir=splicing_dir, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)

    if subcommand == 'predict':
        geneExp_absmax = args.geneExp_absmax
        seqFeature = args.seqFeature
        splicing_max = args.splicing_max
        test_fn = args.predict_list
        model_path = args.model_path
        out_dir = args.outdir
        RBP_dir = args.RBP_dir
        splicing_dir = args.splicing_dir
        sep = args.sep
        CIRIdeepA = args.CIRIdeepA
        splicing_amount = not CIRIdeepA

        phase = 'predict'

        # BUGFIX: unlike train_model, the predict path never created the
        # output directory, so writing per-sample results failed.
        os.makedirs(out_dir, exist_ok=True)

        print('Loading file from %s' % seqFeature)
        seqFeature_df = read_sequence_feature(seqFeature)

        print('Loading gene expression from %s' % geneExp_absmax)
        geneExp_absmax, geneExp_colnames = read_geneExp_absmax(geneExp_absmax)

        if splicing_amount:
            print('Loading splicing amount feature from %s' % splicing_max)
            splicing_max = read_splicing_amount_max(splicing_max)

        print('Loading model weight from %s' % model_path)
        # renamed from `clf` to avoid shadowing the clf model class
        model = load_model(model_path)
        test_list = read_train_list(test_fn)

        print('Prediction...')
        for i in range(len(test_list)):
            inputfile = test_list.label_fn[i]
            sample = test_list.index[i]
            X_tmp, rownames_tmp, colnames = construct_data_from_label(inputfile,
              sample, seqFeature_df, geneExp_absmax, geneExp_colnames, RBP_dir, splicing_max=splicing_max, splicing_amount=splicing_amount, phase=phase, splicing_dir=splicing_dir, sep=sep, CIRIdeepA=CIRIdeepA)
            test_data = {'X':X_tmp, 'rownames':rownames_tmp}
            y_pred = model.predict([test_data['X']])
            write_current_pred(test_data, y_pred, '%s/%s.txt' % (out_dir, sample), CIRIdeepA, phase='predict')

    return

# Script entry point: run the command-line interface when executed directly.
if __name__ == '__main__':
    main()