ModelZoo / CIRI-deep_tensorflow, commit a2a9ae7a
Authored Jul 25, 2023 by adaZ-9
Commit message: modified
Parent: f1f59b87
1 changed file with 20 additions and 31 deletions

CIRIdeep.py  +20  -31
@@ -15,7 +15,7 @@ from keras.utils import np_utils
 import math
 import config
-class SiameseNet():
+class clf():
     def __init__(self, n_in=6523, hidden=[1200, 500, 300, 200], drop_out=[0, 0, 0, 0, 0], batch_size=512, kernel_initializer='glorot_normal', learning_rate=0.001, splicing_amount=True, CIRIdeepA=False):
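This commit renames `SiameseNet` to `clf`; the constructor signature is unchanged, and the code that actually builds the network is collapsed out of the diff. For orientation only, a dense binary classifier assembled from these hyperparameters (layer widths from `hidden`, per-layer dropout from `drop_out`, `glorot_normal` initialization, Adam at `learning_rate`) could be sketched as below. This is an illustrative assumption under Keras 2.x, not the file's actual `clf` code, and `build_mlp` is a hypothetical name.

```python
# Hypothetical sketch only: the commit does not show how clf assembles its network.
from keras.models import Sequential
from keras.layers import Dense, Dropout, LeakyReLU
from keras.optimizers import Adam

def build_mlp(n_in=6523, hidden=(1200, 500, 300, 200),
              drop_out=(0, 0, 0, 0), kernel_initializer='glorot_normal',
              learning_rate=0.001):
    """Dense binary classifier built from the constructor hyperparameters."""
    model = Sequential()
    for i, n_units in enumerate(hidden):
        kwargs = {'input_dim': n_in} if i == 0 else {}
        model.add(Dense(n_units, kernel_initializer=kernel_initializer, **kwargs))
        model.add(LeakyReLU())          # comments elsewhere in the file note LeakyReLU avoiding NaNs
        if drop_out[i] > 0:
            model.add(Dropout(drop_out[i]))
    model.add(Dense(1, activation='sigmoid'))  # binary output
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy', metrics=['accuracy'])
    return model
```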
@@ -69,7 +69,6 @@ class SiameseNet():
         return Y_pred
 #######################################
 def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, geneExp_colnames, odir, RBP_dir, splicing_max='', splicing_dir='', splicing_amount=False, n_epoch=100, CIRIdeepA=False):
     '''
     generate dataset for training while fill test data
@@ -141,7 +140,6 @@ def generate_data_batch(train_list, test_data, seqFeature_df, geneExp_absmax, ge
     return
 ###################################
 def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12345, CIRIdeepA=False):
     '''
@@ -156,7 +154,7 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12
     if not CIRIdeepA:
         # half pos half neg in test data
-        test_size = int(len(rownames) * test_prop * 0.5)  ## should we take the max of pos and neg?
+        test_size = int(len(rownames) * test_prop * 0.5)
         pos_idx = np.where(Y == 1)[0]
         neg_idx = np.where(Y == 0)[0]
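For reference, the half-positive/half-negative held-out split that this block sets up can be written out in full. The following is a minimal, self-contained sketch of the same idea, not the repository's `split_testdata_out`; the helper name and toy data are illustrative.

```python
import numpy as np

def balanced_test_split(Y, rownames, test_prop=0.05, random_seed=12345):
    """Draw an equal number of positive and negative rows for the test set."""
    rng = np.random.RandomState(random_seed)
    test_size = int(len(rownames) * test_prop * 0.5)   # per class
    pos_idx = np.where(Y == 1)[0]
    neg_idx = np.where(Y == 0)[0]
    test_idx = np.concatenate([rng.choice(pos_idx, test_size, replace=False),
                               rng.choice(neg_idx, test_size, replace=False)])
    train_idx = np.setdiff1d(np.arange(len(Y)), test_idx)
    return train_idx, test_idx

# Example: 1000 labels, 5% held out, half positive and half negative
Y = np.random.randint(0, 2, 1000)
rownames = np.array(['circ_%i' % i for i in range(1000)])
train_idx, test_idx = balanced_test_split(Y, rownames)
print(len(test_idx))   # 50 (25 positives + 25 negatives)
```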
@@ -195,7 +193,6 @@ def split_testdata_out(inputdata, circid_test=[], test_prop=0.05, random_seed=12
     return X_train, Y_train, X_test, Y_test, rownames_train, rownames_test
 ######################
 def split_eval_train_data(inputdata, n_fold=5):  #######
     '''
@@ -224,9 +221,6 @@ def split_eval_train_data(inputdata, n_fold=5): #######
     return X_train, X_val, Y_train, Y_val, rownames_train, rownames_val
 ###################################
 # idx_inbatch_list = split_data_into_balanced_minibatch(X_a_train, X_b_train, Y_train, rownames_train)
 def split_data_into_balanced_minibatch(Y_train, batch_size=64, pos_prop=0.5, CIRIdeepA=False):
     '''
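Only the signature of `split_data_into_balanced_minibatch` is visible here; its body is unchanged and collapsed. One plausible reading of such a helper, which returns index arrays whose positive fraction is roughly `pos_prop`, is sketched below. This is a hypothetical stand-in, not the file's implementation.

```python
import numpy as np

def balanced_minibatch_indices(Y_train, batch_size=64, pos_prop=0.5, seed=0):
    """Build index arrays whose positive fraction is approximately pos_prop."""
    rng = np.random.RandomState(seed)
    pos = rng.permutation(np.where(Y_train == 1)[0])
    neg = rng.permutation(np.where(Y_train == 0)[0])
    n_pos = int(round(batch_size * pos_prop))
    n_neg = batch_size - n_pos
    n_batch = min(len(pos) // n_pos, len(neg) // n_neg)
    batches = []
    for b in range(n_batch):
        idx = np.concatenate([pos[b * n_pos:(b + 1) * n_pos],
                              neg[b * n_neg:(b + 1) * n_neg]])
        batches.append(rng.permutation(idx))   # shuffle within the batch
    return batches

# idx_inbatch_list = balanced_minibatch_indices(Y_train, batch_size=64, pos_prop=0.5)
```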
@@ -286,7 +280,7 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10,
     n_batch = len(idx_inbatch_list)
     # train on batch for n times
-    for i in range(0, n_batch):  # train_on_batch update weight?
+    for i in range(0, n_batch):
         X_train, Y_train, X_val, Y_val, rownames_train, rownames_val = \
             inputdata['X_train'], inputdata['Y_train'], inputdata['X_val'], inputdata['Y_val'], inputdata['rownames_train'], inputdata['rownames_val']
@@ -294,12 +288,11 @@ def train_on_balanced_batch(model, inputdata, batch_size=64, validation_freq=10,
         print('loss on batch%i: %.3f' % (i, loss_on_batch[0]))
         print('accuracy on batch%i: %.3f' % (i, loss_on_batch[1]))
         # validation: loss_val, loss_train, roc, pr  # validation change the weight? why the loss becomes unstable after evaluate
         if i == n_batch - 1:
             Y_predict = model.predict(X_val)
-            eval_val = model.evaluate(X_val, Y_val, verbose=1)  # is eval_val a loss?
-            eval_train = model.evaluate(X_train, Y_train, verbose=1)  # loss unstable  # LeakyReLU better than ReLU, prevent the nan occur
+            eval_val = model.evaluate(X_val, Y_val, verbose=1)
+            eval_train = model.evaluate(X_train, Y_train, verbose=1)
             if not CIRIdeepA:
                 roc = metrics.roc_auc_score(Y_val, Y_predict)
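The validation block above relies on `model.evaluate` returning `[loss, accuracy]` and on `sklearn.metrics` for the ROC and PR scores that later appear in the logs (`auroc_eval`, `aupr_eval`, ...). A small, self-contained example of the same metric calls, on made-up data:

```python
import numpy as np
from sklearn import metrics

# Toy validation labels and predicted probabilities
Y_val = np.array([0, 0, 1, 1, 1, 0])
Y_predict = np.array([0.1, 0.4, 0.8, 0.65, 0.9, 0.3])

roc = metrics.roc_auc_score(Y_val, Y_predict)           # area under the ROC curve
pr = metrics.average_precision_score(Y_val, Y_predict)  # area under the PR curve
print('auroc=%.3f aupr=%.3f' % (roc, pr))
```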
@@ -349,7 +342,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False):
         if firstline:
-            header = {ele[i]: i for i in range(len(ele))}  # dict mapping each column name to its index
+            header = {ele[i]: i for i in range(len(ele))}
             firstline = False
             continue
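The `header` comprehension whose trailing comment was dropped simply maps each column name in the first line of the label file to its position, so later rows can be indexed by name. A tiny illustration with hypothetical column names:

```python
line = 'circRNA_ID\tcoverage\tFDR\tdeltaPSI'       # hypothetical header line
ele = line.strip().split('\t')
header = {ele[i]: i for i in range(len(ele))}      # column name -> column index
# e.g. header['FDR'] == 2, so a data row can be read as fields[header['FDR']]
```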
@@ -386,7 +379,7 @@ def read_label_fn(label_fn, min_read_cov=20, significance=0.1, CIRIdeepA=False):
         if firstline:
-            header = {ele[i]: i for i in range(len(ele))}  # dict mapping each column name to its index
+            header = {ele[i]: i for i in range(len(ele))}
             firstline = False
             continue
@@ -443,10 +436,9 @@ def read_geneExp_absmax(fn):
 def read_geneExp(sample, geneExp_absmax, nrow=1, phase='train', RBPexp_dir=''):
     # geneExp_fn = os.path.join('/xtdisk/gaoyuan_group/zhouzh/ProjectDeepLearning/Training/TrainingResource_afterfilter/RBPexp', sample+'_rpb.csv')
     geneExp_fn = os.path.join(RBPexp_dir, sample + '_rpb.csv')
     df = pd.read_csv(geneExp_fn, index_col=0, sep='\t').transpose()
-    vec = df.values.flatten() / geneExp_absmax  # each gene's expression divided by that gene's maximum
+    vec = df.values.flatten() / geneExp_absmax
     if phase == 'predict':
         vec[vec > 1] = 1
     mat = np.tile(vec, (nrow, 1))
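`read_geneExp` scales each gene's expression by its precomputed maximum (`geneExp_absmax`), caps values above 1 at prediction time, and tiles the resulting vector so one expression profile can be paired with `nrow` circRNA feature rows. A compact illustration with toy numbers (not the repository's data):

```python
import numpy as np
import pandas as pd

# Toy expression table: one sample (row), three RBP genes (columns)
df = pd.DataFrame({'RBP1': [4.0], 'RBP2': [10.0], 'RBP3': [0.5]})
geneExp_absmax = np.array([8.0, 10.0, 1.0])   # per-gene maxima from the training set

vec = df.values.flatten() / geneExp_absmax    # each gene divided by its training maximum
vec[vec > 1] = 1                              # 'predict' phase: cap unseen samples at 1
mat = np.tile(vec, (4, 1))                    # repeat the profile for nrow=4 circRNA rows
print(vec)         # [0.5 1.  0.5]
print(mat.shape)   # (4, 3)
```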
@@ -457,7 +449,6 @@ def read_splicing_amount(sample1, sample2, splicing_max, eid_list, phase='train'
     splicing_amount_max_eidlist = splicing_max.loc[eid_list]
     # splicing_dir = '/xtdisk/gaoyuan_group/zhouzh/ProjectDeepLearning/Training/TrainingResource_afterfilter/SplicingAmount'
     splicing_amount_fn1 = os.path.join(splicing_dir, sample1 + '.output')
     splicing_amount_fn2 = os.path.join(splicing_dir, sample2 + '.output')
@@ -595,7 +586,6 @@ def plot_auroc(roc_val, roc_test, loss_training, loss_eval, loss_test, test_freq
     plt.title('Evaluation/Test ROC')
     plt.plot(range(1, len(roc_val) + 1), roc_val)
     plt.plot(range(test_freq, len(roc_val) + 1, test_freq), roc_test)
     # plt.legend(loc = 'lower right')
     plt.ylabel('Auroc')
     plt.xlabel('Step')
     plt.savefig(os.path.join(odir, 'evaluation test roc.png'))
@@ -656,8 +646,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
     print('input train list: %s' % fn)
     train_list = read_train_list(fn)
-    siam_model = SiameseNet(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
-    test_data = {'X': np.empty((0, siam_model.n_in), dtype='float32'), 'Y': np.asarray([], dtype='float32'), 'rownames': np.asarray([], dtype='str')}
+    m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
+    test_data = {'X': np.empty((0, m.n_in), dtype='float32'), 'Y': np.asarray([], dtype='float32'), 'rownames': np.asarray([], dtype='str')}
     test_data_lst = {'X': [], 'Y': [], 'rownames': []}
     print('Loading features...')
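`test_data` starts as empty float32/str arrays sized to the model's `n_in`, while `test_data_lst` collects per-file chunks in plain Python lists; concatenating a list of arrays once at the end is much cheaper than growing a NumPy array row by row. The pattern in isolation, with illustrative shapes:

```python
import numpy as np

n_in = 6523
test_data_lst = {'X': [], 'Y': [], 'rownames': []}

# Pretend two training files each contributed a chunk of held-out rows
for chunk_size in (3, 5):
    test_data_lst['X'].append(np.zeros((chunk_size, n_in), dtype='float32'))
    test_data_lst['Y'].append(np.zeros(chunk_size, dtype='float32'))
    test_data_lst['rownames'].append(np.array(['row'] * chunk_size, dtype='str'))

# One concatenation per key, as done later in train_model
test_data_lst['X'] = np.concatenate(test_data_lst['X'])
test_data_lst['Y'] = np.concatenate(test_data_lst['Y'])
test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames'])
print(test_data_lst['X'].shape)   # (8, 6523)
```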
@@ -691,7 +681,6 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
     n_batch = 0
     patience = 0
     step = 0
     # test_freq = config.test_freq
     test_freq = math.ceil(len(train_list.index) / 4)
     print('step', 'auroc_eval', 'aupr_eval', 'loss_eval', 'acc_eval', 'loss_train', 'acc_train', sep='\t', file=open(odir + '/roc_pr_losseval_losstrain.log', 'a+'))
     print('step', 'patience', 'auroc_test', 'aupr_test', 'loss_test', 'acc_test', sep='\t', file=open(odir + '/roc_pr_loss_test.log', 'a+'))
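Each log line here is written with `print(..., file=open(path, 'a+'))`, which opens a new handle on every call and never closes it explicitly; CPython usually closes it at garbage collection, but a context manager is the safer idiom. A hedged alternative with the same tab-separated format (`append_log` is a hypothetical helper, not part of the repository):

```python
import os

def append_log(odir, name, *fields):
    """Append one tab-separated line to a log file, closing the handle promptly."""
    with open(os.path.join(odir, name), 'a+') as fh:
        print(*fields, sep='\t', file=fh)

# Equivalent to the header line written in train_model:
# append_log(odir, 'roc_pr_losseval_losstrain.log',
#            'step', 'auroc_eval', 'aupr_eval', 'loss_eval', 'acc_eval', 'loss_train', 'acc_train')
```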
@@ -712,7 +701,7 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
         inputdata = {'X_train': X_train, 'X_val': X_val, 'Y_train': Y_train, 'Y_val': Y_val, 'rownames_train': rownames_train, 'rownames_val': rownames_val}
         print('step %i optimization start.' % step)
-        best_metrics = siam_model.fit(inputdata)
+        best_metrics = m.fit(inputdata)
         if not CIRIdeepA:
             loss_train_lst.append(best_metrics['loss_train'][0])
             acc_train_lst.append(best_metrics['loss_train'][1])
@@ -741,8 +730,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
         test_data_lst['Y'] = np.concatenate(test_data_lst['Y'])
         test_data_lst['rownames'] = np.concatenate(test_data_lst['rownames'])
-        loss_test = siam_model.model.evaluate(test_data_lst['X'], test_data_lst['Y'], verbose=1)
-        Y_pred = siam_model.predict(test_data_lst)
+        loss_test = m.model.evaluate(test_data_lst['X'], test_data_lst['Y'], verbose=1)
+        Y_pred = m.predict(test_data_lst)
         if not CIRIdeepA:
             auroc_test = metrics.roc_auc_score(test_data_lst['Y'], Y_pred)
@@ -776,18 +765,18 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
         # compare the metrics
         if loss_test[0] < best_loss:
             best_loss = loss_test[0]
-            siam_model.model.save(os.path.join(odir, str(step) + '_model_weight.h5'))
+            m.model.save(os.path.join(odir, str(step) + '_model_weight.h5'))
             patience = 0
         else:
             patience += 1
-            siam_model.model.save(os.path.join(odir, str(step) + '_model_weight.h5'))
+            m.model.save(os.path.join(odir, str(step) + '_model_weight.h5'))
             if patience > patience_limit:
                 print('patience > patience_limit. Early Stop.')
                 break
-        del siam_model
+        del m
         K.clear_session()
-        siam_model = SiameseNet(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
-        siam_model.model.load_weights(os.path.join(odir, str(step) + '_model_weight.h5'))
+        m = clf(**config.architecture, splicing_amount=splicing_amount, CIRIdeepA=CIRIdeepA)
+        m.model.load_weights(os.path.join(odir, str(step) + '_model_weight.h5'))
         # print test result
         print('*' * 50)
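This block is a manual checkpoint-and-early-stop loop: the weights are saved after every test step, `patience` is reset when the test loss improves and incremented otherwise, training stops once it exceeds `patience_limit`, and the Keras session is then cleared before the saved weights are reloaded into a fresh `clf`. The same control flow, reduced to a skeleton (the `make_model` and `evaluate_step` callables are hypothetical stand-ins):

```python
import os

def training_loop(make_model, evaluate_step, odir, n_steps, patience_limit=3):
    """Skeleton of the save/compare/reload pattern used in train_model (illustrative)."""
    best_loss = float('inf')
    patience = 0
    m = make_model()
    for step in range(n_steps):
        loss_test = evaluate_step(m, step)                  # e.g. model.evaluate(...)[0]
        weight_fn = os.path.join(odir, str(step) + '_model_weight.h5')
        m.model.save(weight_fn)                             # checkpoint is written either way
        if loss_test < best_loss:
            best_loss = loss_test
            patience = 0
        else:
            patience += 1
            if patience > patience_limit:
                print('patience > patience_limit. Early Stop.')
                break
        # rebuild a fresh model and restore the weights just saved,
        # mirroring the del m / K.clear_session() / load_weights sequence in the diff
        m = make_model()
        m.model.load_weights(weight_fn)
    return m
```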
@@ -806,8 +795,8 @@ def train_model(fn, odir, geneExp_absmax, seqFeature, RBP_dir, splicing_max='',
         loss_test[0], loss_test[1]), file=open(odir + '/roc_pr_loss_test.log', 'a+'))
-    siam_model.model.save(os.path.join(odir, 'final_model_weight.h5'))
-    Y_pred = siam_model.predict(test_data)
+    m.model.save(os.path.join(odir, 'final_model_weight.h5'))
+    Y_pred = m.predict(test_data)
     auroc_test = metrics.roc_auc_score(test_data['Y'], Y_pred)
     print('auroc after 10 epoch:', auroc_test)