train_model.ipynb 3.77 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
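    "# train_model\n",
    "\n",
    "Builds the selected architecture (plain-CNN, ResNet or ResNet-CBAM) for the selected time bin (64ms, 128ms or 256ms), compiles it as a 2-class classifier, and trains it on the train/validation data loaded by `data_utils` using `model_utils.train_model`."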
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "import tensorflow as tf\n",
    "\n",
    "import data_utils\n",
    "import model_utils\n",
    "\n",
    "# Ask TensorFlow to start with a small GPU allocation and grow it on demand\n",
    "# instead of reserving the whole device up front.\n",
    "config = tf.compat.v1.ConfigProto(\n",
    "    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)\n",
    ")\n",
    "# NOTE(review): `session` is not passed to Keras anywhere in this notebook (no\n",
    "# set_session call); confirm allow_growth actually applies to the Keras models below.\n",
    "session = tf.compat.v1.Session(config=config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Time-bin width of the dataset: one of \"64ms\", \"128ms\", \"256ms\".\n",
    "# NOTE: the name `choonse_time_bin` (sic) is used by the other cells; rename everywhere or not at all.\n",
    "choonse_time_bin = \"64ms\"\n",
    "# Architecture: one of \"plain-CNN\", \"ResNet\", \"ResNet-CBAM\".\n",
    "choose_model = \"plain-CNN\"\n",
    "\n",
    "# Dataset root handed to data_utils.get_train_val_data; empty placeholder - set it before running.\n",
    "data_set_dir = r\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load and pre-process the train and validation splits for the selected time bin.\n",
    "(train_x, train_y, train_info), (val_x, val_y, val_info) = data_utils.get_train_val_data(\n",
    "    data_set_dir, choonse_time_bin\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Map \"<architecture>-<time bin>\" to the model_utils builder for that combination.\n",
    "model_func_dic = {\n",
    "    \"plain-CNN-64ms\": model_utils.plain_cnn_64ms,\n",
    "    \"plain-CNN-128ms\": model_utils.plain_cnn_128ms,\n",
    "    \"plain-CNN-256ms\": model_utils.plain_cnn_256ms,\n",
    "\n",
    "    \"ResNet-64ms\": model_utils.resnet_64ms,\n",
    "    \"ResNet-128ms\": model_utils.resnet_128ms,\n",
    "    \"ResNet-256ms\": model_utils.resnet_256ms,\n",
    "\n",
    "    \"ResNet-CBAM-64ms\": model_utils.resnet_CBAM_64ms,\n",
    "    \"ResNet-CBAM-128ms\": model_utils.resnet_CBAM_128ms,\n",
    "    \"ResNet-CBAM-256ms\": model_utils.resnet_CBAM_256ms,\n",
    "}\n",
    "\n",
    "model_key = choose_model + \"-\" + choonse_time_bin\n",
    "# Fail here, listing the valid options, rather than later with a confusing\n",
    "# \"'NoneType' object is not callable\" when the model is built.\n",
    "if model_key not in model_func_dic:\n",
    "    raise ValueError(\n",
    "        \"Unknown model/time-bin combination {!r}; expected one of: {}\".format(\n",
    "            model_key, \", \".join(sorted(model_func_dic))\n",
    "        )\n",
    "    )\n",
    "model_func = model_func_dic[model_key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Reset global Keras state so re-running this cell starts from a clean slate.\n",
    "keras.backend.clear_session()\n",
    "\n",
    "# 2-class model; the Input shape is train_x's shape without its leading (sample) axis.\n",
    "input_shape, nb_classes = train_x.shape[1:], 2\n",
    "input_layer = keras.layers.Input(shape=input_shape, name='input')\n",
    "\n",
    "model = model_func(input_layer, nb_classes)\n",
    "# NOTE: no separator between the parts (e.g. \"plain-CNN64ms\"); kept as-is because it is\n",
    "# passed to model_utils.train_model as modelName and may determine output names.\n",
    "model_name = choose_model + choonse_time_bin\n",
    "\n",
    "# `lr` is the legacy Keras argument name (Keras >= 2.3 prefers `learning_rate`); kept because\n",
    "# older Keras releases (2.2.x) only accept `lr`.\n",
    "adam = keras.optimizers.Adam(lr=0.0001, beta_1=0.95, beta_2=0.999, epsilon=1e-08)\n",
    "model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "trainEpochs = 1000\n",
    "trainBatchSize = 32\n",
    "\n",
    "# outDir is where model_utils.train_model writes its outputs (presumably weights/logs - confirm in model_utils).\n",
    "model_utils.train_model(\n",
    "    model,\n",
    "    train_x, train_y,\n",
    "    val_x, val_y,\n",
    "    trainEpochs,\n",
    "    trainBatchSize,\n",
    "    modelName=model_name,\n",
    "    outDir=\"input/out_gpu/\",\n",
    "    binSize=choonse_time_bin,\n",
    ")\n",
    "print(\"done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}