Commit a2a9ae7a authored by adaZ-9's avatar adaZ-9
Browse files

modified

parent f1f59b87
...@@ -15,7 +15,7 @@ from keras.utils import np_utils ...@@ -15,7 +15,7 @@ from keras.utils import np_utils
import math import math
import config import config
class SiameseNet(): class clf():
def __init__(self, n_in=6523, hidden=[1200, 500, 300, 200], drop_out=[0, 0, 0, 0, 0], batch_size=512, kernel_initializer='glorot_normal', learning_rate=0.001, splicing_amount=True, CIRIdeepA=False): def __init__(self, n_in=6523, hidden=[1200, 500, 300, 200], drop_out=[0, 0, 0, 0, 0], batch_size=512, kernel_initializer='glorot_normal', learning_rate=0.001, splicing_amount=True, CIRIdeepA=False):
...@@ -69,7 +69,6 @@ class SiameseNet(): ...@@ -69,7 +69,6 @@ class SiameseNet():
return Y_pred return Y_pred
#######################################
def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, geneExp_colnames, odir, RBP_dir, splicing_max='', splicing_dir='', splicing_amount=False, n_epoch=100, CIRIdeepA=False): def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, geneExp_colnames, odir, RBP_dir, splicing_max='', splicing_dir='', splicing_amount=False, n_epoch=100, CIRIdeepA=False):
''' '''
generate dataset for training while fill test data generate dataset for training while fill test data
...@@ -141,7 +140,6 @@ def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, ge ...@@ -141,7 +140,6 @@ def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, ge
return return
###################################
def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12345, CIRIdeepA=False): def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12345, CIRIdeepA=False):
''' '''
...@@ -156,7 +154,7 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12 ...@@ -156,7 +154,7 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12
if not CIRIdeepA: if not CIRIdeepA:
# half pos half neg in test data # half pos half neg in test data
test_size = int(len(rownames) * test_prop * 0.5) ## 要不要取pos和neg的最大? test_size = int(len(rownames) * test_prop * 0.5)
pos_idx = np.where(Y == 1)[0] pos_idx = np.where(Y == 1)[0]
neg_idx = np.where(Y == 0)[0] neg_idx = np.where(Y == 0)[0]
...@@ -195,7 +193,6 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12 ...@@ -195,7 +193,6 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12
return X_train, Y_train, X_test, Y_test, rownames_train, rownames_test return X_train, Y_train, X_test, Y_test, rownames_train, rownames_test
######################
def split_eval_train_data(inputdata, n_fold=5): ####### def split_eval_train_data(inputdata, n_fold=5): #######
''' '''
...@@ -224,9 +221,6 @@ def split_eval_train_data(inputdata, n_fold=5): ####### ...@@ -224,9 +221,6 @@ def split_eval_train_data(inputdata, n_fold=5): #######
return X_train, X_val, Y_train, Y_val, rownames_train, rownames_val return X_train, X_val, Y_train, Y_val, rownames_train, rownames_val
###################################
# idx_inbatch_list = split_data_into_balanced_minibatch(X_a_train, X_b_train, Y_train, rownames_train)
def split_data_into_balanced_minibatch(Y_train, batch_size=64, pos_prop=0.5, CIRIdeepA=False): def split_data_into_balanced_minibatch(Y_train, batch_size=64, pos_prop=0.5, CIRIdeepA=False):
''' '''
...@@ -286,7 +280,7 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10, ...@@ -286,7 +280,7 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10,
n_batch = len(idx_inbatch_list) n_batch = len(idx_inbatch_list)
# train on batch for n times # train on batch for n times
for i in range(0, n_batch): # train_on_batch update weight? for i in range(0, n_batch):
X_train, Y_train, X_val, Y_val, rownames_train, rownames_val = \ X_train, Y_train, X_val, Y_val, rownames_train, rownames_val = \
inputdata['X_train'], inputdata['Y_train'], inputdata['X_val'], inputdata['Y_val'], inputdata['rownames_train'], inputdata['rownames_val'] inputdata['X_train'], inputdata['Y_train'], inputdata['X_val'], inputdata['Y_val'], inputdata['rownames_train'], inputdata['rownames_val']
...@@ -294,12 +288,11 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10, ...@@ -294,12 +288,11 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10,
print('loss on batch%i: %.3f' % (i, loss_on_batch[0])) print('loss on batch%i: %.3f' % (i, loss_on_batch[0]))
print('accuracy on batch%i: %.3f' % (i, loss_on_batch[1])) print('accuracy on batch%i: %.3f' % (i, loss_on_batch[1]))
# validation: loss_val, loss_train, roc, pr # validation change the weight? why the loss becomes unstable after evaluate
if i == n_batch - 1: if i == n_batch - 1:
Y_predict = model.predict(X_val) Y_predict = model.predict(X_val)
eval_val = model.evaluate(X_val, Y_val, verbose=1) # is eval_val a loss? eval_val = model.evaluate(X_val, Y_val, verbose=1)
eval_train = model.evaluate(X_train, Y_train, verbose=1) # loss unstable # LeakyReLU better than ReLU, prevent the nan occur eval_train = model.evaluate(X_train, Y_train, verbose=1)
if not CIRIdeepA: if not CIRIdeepA:
roc = metrics.roc_auc_score(Y_val, Y_predict) roc = metrics.roc_auc_score(Y_val, Y_predict)
...@@ -349,7 +342,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False): ...@@ -349,7 +342,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False):
if firstline: if firstline:
header = {ele[i]:i for i in range(len(ele))} # 列名与列名index对应的字典 header = {ele[i]:i for i in range(len(ele))}
firstline = False firstline = False
continue continue
...@@ -386,7 +379,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False): ...@@ -386,7 +379,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False):
if firstline: if firstline:
header = {ele[i]:i for i in range(len(ele))} # 列名与列名index对应的字典 header = {ele[i]:i for i in range(len(ele))}
firstline = False firstline = False
continue continue
...@@ -443,10 +436,9 @@ def read_geneExp_absmax(fn): ...@@ -443,10 +436,9 @@ def read_geneExp_absmax(fn):
def read_geneExp(sample, geneExp_absmax, nrow=1, phase='train', RBPexp_dir=''): def read_geneExp(sample, geneExp_absmax, nrow=1, phase='train', RBPexp_dir=''):
# geneExp_fn = os.path.join('/xtdisk/gaoyuan_group/zhouzh/ProjectDeepLearning/Training/TrainingResource_afterfilter/RBPexp', sample+'_rpb.csv')
geneExp_fn = os.path.join(RBPexp_dir, sample + '_rpb.csv') geneExp_fn = os.path.join(RBPexp_dir, sample + '_rpb.csv')
df = pd.read_csv(geneExp_fn, index_col=0, sep='\t').transpose() df = pd.read_csv(geneExp_fn, index_col=0, sep='\t').transpose()
vec = df.values.flatten() / geneExp_absmax # 每个基因的表达比上这个基因的最大值 vec = df.values.flatten() / geneExp_absmax
if phase == 'predict': if phase == 'predict':
vec[vec > 1] = 1 vec[vec > 1] = 1
mat = np.tile(vec, (nrow, 1)) mat = np.tile(vec, (nrow, 1))
...@@ -457,7 +449,6 @@ def read_splicing_amount(sample1, sample2, splicing_max, eid_list, phase='train' ...@@ -457,7 +449,6 @@ def read_splicing_amount(sample1, sample2, splicing_max, eid_list, phase='train'
splicing_amount_max_eidlist = splicing_max.loc[eid_list] splicing_amount_max_eidlist = splicing_max.loc[eid_list]
# splicing_dir = '/xtdisk/gaoyuan_group/zhouzh/ProjectDeepLearning/Training/TrainingResource_afterfilter/SplicingAmount'
splicing_amount_fn1 = os.path.join(splicing_dir, sample1+'.output') splicing_amount_fn1 = os.path.join(splicing_dir, sample1+'.output')
splicing_amount_fn2 = os.path.join(splicing_dir, sample2+'.output') splicing_amount_fn2 = os.path.join(splicing_dir, sample2+'.output')
...@@ -595,7 +586,6 @@ def plot_auroc(roc_val, roc_test, loss_training, loss_eval, loss_test, test_freq ...@@ -595,7 +586,6 @@ def plot_auroc(roc_val, roc_test, loss_training, loss_eval, loss_test, test_freq
plt.title('Evaluation/Test ROC') plt.title('Evaluation/Test ROC')
plt.plot(range(1, len(roc_val) + 1), roc_val) plt.plot(range(1, len(roc_val) + 1), roc_val)
plt.plot(range(test_freq, len(roc_val) + 1, test_freq), roc_test) plt.plot(range(test_freq, len(roc_val) + 1, test_freq), roc_test)
# plt.legend(loc = 'lower right')
plt.ylabel('Auroc') plt.ylabel('Auroc')
plt.xlabel('Step') plt.xlabel('Step')
plt.savefig(os.path.join(odir, 'evaluation test roc.png')) plt.savefig(os.path.join(odir, 'evaluation test roc.png'))
...@@ -656,8 +646,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -656,8 +646,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
print('input train list: %s' % fn) print('input train list: %s' % fn)
train_list = read_train_list(fn) train_list = read_train_list(fn)
siam_model = SiameseNet(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA) m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
test_data = {'X':np.empty((0, siam_model.n_in), dtype='float32'), 'Y':np.asarray([], dtype='float32'), 'rownames':np.asarray([], dtype='str')} test_data = {'X':np.empty((0, m.n_in), dtype='float32'), 'Y':np.asarray([], dtype='float32'), 'rownames':np.asarray([], dtype='str')}
test_data_lst = {'X':[], 'Y':[], 'rownames':[]} test_data_lst = {'X':[], 'Y':[], 'rownames':[]}
print('Loading features...') print('Loading features...')
...@@ -691,7 +681,6 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -691,7 +681,6 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
n_batch = 0 n_batch = 0
patience = 0 patience = 0
step = 0 step = 0
# test_freq = config.test_freq
test_freq = math.ceil(len(train_list.index)/4) test_freq = math.ceil(len(train_list.index)/4)
print('step', 'auroc_eval', 'aupr_eval', 'loss_eval', 'acc_eval', 'loss_train', 'acc_train', sep='\t', file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+')) print('step', 'auroc_eval', 'aupr_eval', 'loss_eval', 'acc_eval', 'loss_train', 'acc_train', sep='\t', file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+'))
print('step', 'patience', 'auroc_test', 'aupr_test', 'loss_test', 'acc_test', sep='\t', file=open(odir + '/roc_pr_loss_test.log', 'a+')) print('step', 'patience', 'auroc_test', 'aupr_test', 'loss_test', 'acc_test', sep='\t', file=open(odir + '/roc_pr_loss_test.log', 'a+'))
...@@ -712,7 +701,7 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -712,7 +701,7 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
inputdata = {'X_train':X_train, 'X_val':X_val, 'Y_train':Y_train, 'Y_val':Y_val, 'rownames_train':rownames_train, 'rownames_val':rownames_val} inputdata = {'X_train':X_train, 'X_val':X_val, 'Y_train':Y_train, 'Y_val':Y_val, 'rownames_train':rownames_train, 'rownames_val':rownames_val}
print('step %i optimization start.' % step) print('step %i optimization start.' % step)
best_metrics = siam_model.fit(inputdata) best_metrics = m.fit(inputdata)
if not CIRIdeepA: if not CIRIdeepA:
loss_train_lst.append(best_metrics['loss_train'][0]) loss_train_lst.append(best_metrics['loss_train'][0])
acc_train_lst.append(best_metrics['loss_train'][1]) acc_train_lst.append(best_metrics['loss_train'][1])
...@@ -741,8 +730,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -741,8 +730,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
test_data_lst['Y'] = np.concatenate(test_data_lst['Y']) test_data_lst['Y'] = np.concatenate(test_data_lst['Y'])
test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames']) test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames'])
loss_test = siam_model.model.evaluate(test_data_lst['X'], test_data_lst['Y'], verbose=1) loss_test = m.model.evaluate(test_data_lst['X'], test_data_lst['Y'], verbose=1)
Y_pred = siam_model.predict(test_data_lst) Y_pred = m.predict(test_data_lst)
if not CIRIdeepA: if not CIRIdeepA:
auroc_test = metrics.roc_auc_score(test_data_lst['Y'], Y_pred) auroc_test = metrics.roc_auc_score(test_data_lst['Y'], Y_pred)
...@@ -776,18 +765,18 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -776,18 +765,18 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
# compare the metrics # compare the metrics
if loss_test[0] < best_loss: if loss_test[0] < best_loss:
best_loss = loss_test[0] best_loss = loss_test[0]
siam_model.model.save(os.path.join(odir, str(step)+'_model_weight.h5')) m.model.save(os.path.join(odir, str(step)+'_model_weight.h5'))
patience = 0 patience = 0
else: else:
patience += 1 patience += 1
siam_model.model.save(os.path.join(odir, str(step)+'_model_weight.h5')) m.model.save(os.path.join(odir, str(step)+'_model_weight.h5'))
if patience > patience_limit: if patience > patience_limit:
print('patience > patience_limit. Early Stop.') print('patience > patience_limit. Early Stop.')
break break
del siam_model del m
K.clear_session() K.clear_session()
siam_model = SiameseNet(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA) m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
siam_model.model.load_weights(os.path.join(odir, str(step)+'_model_weight.h5')) m.model.load_weights(os.path.join(odir, str(step)+'_model_weight.h5'))
# print test result # print test result
print('*' * 50) print('*' * 50)
...@@ -806,8 +795,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='', ...@@ -806,8 +795,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
loss_test[0], loss_test[1]), loss_test[0], loss_test[1]),
file=open(odir + '/roc_pr_loss_test.log', 'a+')) file=open(odir + '/roc_pr_loss_test.log', 'a+'))
siam_model.model.save(os.path.join(odir, 'final_model_weight.h5')) m.model.save(os.path.join(odir, 'final_model_weight.h5'))
Y_pred = siam_model.predict(test_data) Y_pred = m.predict(test_data)
auroc_test = metrics.roc_auc_score(test_data['Y'], Y_pred) auroc_test = metrics.roc_auc_score(test_data['Y'], Y_pred)
print('auroc after 10 epoch:', auroc_test) print('auroc after 10 epoch:', auroc_test)
......
Markdown is supported
Attach a file by dragging & dropping, selecting or pasting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment