Commit a41fb568 authored by maming's avatar maming
Browse files

Delete train_backup.py

parent 0e81b4a0
Pipeline #3362 canceled with stages
import keras
import tensorflow as tf
import data_utils,model_utils
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
# 64ms, 128ms, 256ms
choonse_time_bin="256ms"
# plain-CNN, ResNet, ResNet-CBAM
choose_model="ResNet-CBAM"
data_set_dir=r"/workspace/binary_distinguish_GRB_by_DL-main/Binary_Distinguish_GRB_Datasetv1/data/dataset_256ms/"
# load and pre-process data (train and validate)
(train_x, train_y, train_info), (val_x, val_y, val_info)=data_utils.get_train_val_data(data_set_dir,choonse_time_bin)
model_func_dic={
"plain-CNN-64ms": model_utils.plain_cnn_64ms,
"plain-CNN-128ms": model_utils.plain_cnn_128ms,
"plain-CNN-256ms": model_utils.plain_cnn_256ms,
"ResNet-64ms": model_utils.resnet_64ms,
"ResNet-128ms": model_utils.resnet_128ms,
"ResNet-256ms": model_utils.resnet_256ms,
"ResNet-CBAM-64ms": model_utils.resnet_CBAM_64ms,
"ResNet-CBAM-128ms": model_utils.resnet_CBAM_128ms,
"ResNet-CBAM-256ms": model_utils.resnet_CBAM_256ms,
}
model_func=model_func_dic.get(choose_model+"-"+choonse_time_bin)
from keras import backend as K
K.clear_session()
input_shape, nb_classes = (train_x.shape[1:]), 2
input_layer = keras.layers.Input(shape=input_shape, name='input')
model = model_func(input_layer, nb_classes)
model_name = choose_model+choonse_time_bin
adam = keras.optimizers.Adam(lr=0.0001, beta_1=0.95, beta_2=0.999, epsilon=1e-08)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
model.summary()
trainEpochs=20
trainBatchSize=32
model_utils.train_model(model, train_x, train_y, val_x, val_y, trainEpochs, trainBatchSize,modelName=model_name,outDir="gpuout/", binSize=choonse_time_bin)
print("done")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment