# train.py (2.08 KB) — last committed by maming.
import keras
import tensorflow as tf
import data_utils,model_utils


# Let TensorFlow allocate GPU memory on demand instead of reserving all of it
# up front. In TF2 the documented way to do this is `set_memory_growth` on
# each physical GPU. The previous `tf.compat.v1.Session` built with
# `allow_growth` was never used or handed to Keras anywhere in this script.
# Memory growth must be configured before a GPU is initialised, which is why
# this runs at the top of the script.
for _gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as _err:
        # Raised when the device was already initialised (e.g. by an earlier
        # import). The setting can no longer change, so report it and continue.
        print(f"Could not enable memory growth on {_gpu.name}: {_err}")
# Time-bin resolution of the input data; one of "64ms", "128ms", "256ms".
choonse_time_bin = "256ms"
# Network architecture; one of "plain-CNN", "ResNet", "ResNet-CBAM".
choose_model = "ResNet-CBAM"

_SUPPORTED_TIME_BINS = ("64ms", "128ms", "256ms")
if choonse_time_bin not in _SUPPORTED_TIME_BINS:
    raise ValueError(
        f"Unsupported time bin {choonse_time_bin!r}; "
        f"expected one of {_SUPPORTED_TIME_BINS}"
    )

# The dataset folder is derived from the selected time bin, so changing
# `choonse_time_bin` also switches the data on disk. Previously
# `dataset_256ms/` was hard-coded regardless of the chosen bin.
# NOTE(review): assumes sibling folders follow the same dataset_<bin>/ naming
# as dataset_256ms/ — confirm for 64ms/128ms.
data_root = r"/workspace/binary_distinguish_GRB_by_DL-main/Binary_Distinguish_GRB_Datasetv1/data/"
data_set_dir = f"{data_root}dataset_{choonse_time_bin}/"

# Load and pre-process the train and validation splits.
(train_x, train_y, train_info), (val_x, val_y, val_info) = data_utils.get_train_val_data(
    data_set_dir, choonse_time_bin
)

trainBatchSize = 32
# Only the training split is shuffled. A buffer of 20000 elements gives a
# uniform shuffle only when the training set has at most 20000 samples.
# Both pipelines prefetch so input preparation overlaps with training.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(20000)
    .batch(trainBatchSize)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_x, val_y))
    .batch(trainBatchSize)
    .prefetch(tf.data.AUTOTUNE)
)

# Registry of model builders keyed by "<architecture>-<time bin>". Each
# builder is called later as builder(input_layer, nb_classes) and returns the
# model that is then compiled and trained.
model_func_dic = {
    "plain-CNN-64ms": model_utils.plain_cnn_64ms,
    "plain-CNN-128ms": model_utils.plain_cnn_128ms,
    "plain-CNN-256ms": model_utils.plain_cnn_256ms,

    "ResNet-64ms": model_utils.resnet_64ms,
    "ResNet-128ms": model_utils.resnet_128ms,
    "ResNet-256ms": model_utils.resnet_256ms,

    "ResNet-CBAM-64ms": model_utils.resnet_CBAM_64ms,
    "ResNet-CBAM-128ms": model_utils.resnet_CBAM_128ms,
    "ResNet-CBAM-256ms": model_utils.resnet_CBAM_256ms,
}
model_key = choose_model + "-" + choonse_time_bin
try:
    model_func = model_func_dic[model_key]
except KeyError:
    # dict.get() would yield None here and only fail later with an obscure
    # "'NoneType' object is not callable"; fail fast with the valid choices.
    raise ValueError(
        f"Unknown model {model_key!r}; choose one of {sorted(model_func_dic)}"
    ) from None
from keras import backend as K

# Start from a clean Keras global state before building the network.
K.clear_session()

# The per-sample shape (everything after the batch axis) defines the input
# layer. The task distinguishes two classes.
input_shape, nb_classes = train_x.shape[1:], 2
input_layer = keras.layers.Input(shape=input_shape, name='input')

model = model_func(input_layer, nb_classes)
# NOTE(review): no separator between model and bin (e.g. "ResNet-CBAM256ms").
# Kept unchanged because it only names outputs of the commented-out
# model_utils.train_model call.
model_name = choose_model + choonse_time_bin

# `lr` is a deprecated alias. Depending on the Keras version it is ignored
# with only a warning (training at the default learning rate instead of 1e-4)
# or rejected outright. `learning_rate` is the supported keyword.
adam = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.95, beta_2=0.999, epsilon=1e-08)
# categorical_crossentropy expects one-hot targets; presumably train_y/val_y
# are one-hot with nb_classes columns — TODO confirm in data_utils.
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
model.summary()
# Fit for a fixed number of epochs, evaluating on the validation pipeline as
# training proceeds.
trainEpochs = 20

# Alternative driver, kept for reference:
# model_utils.train_model(model, train_x, train_y, val_x, val_y, trainEpochs, trainBatchSize,modelName=model_name,outDir="gpuout/", binSize=choonse_time_bin)
# NOTE(review): the fitted model is not saved anywhere in this script.
fit_options = {"validation_data": val_ds, "epochs": trainEpochs}
model.fit(train_ds, **fit_options)

print("done")