{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# train_model\n",
    "\n",
    "Builds one of the `model_utils` architectures (plain-CNN, ResNet, ResNet-CBAM) for a chosen time bin (64ms, 128ms, 256ms), compiles it and trains it on the data returned by `data_utils.get_train_val_data`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Imports and GPU session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "import tensorflow as tf\n",
    "from keras import backend as K\n",
    "\n",
    "import data_utils\n",
    "import model_utils\n",
    "\n",
    "# allow_growth makes TensorFlow grow its GPU memory allocation on demand\n",
    "# instead of reserving the whole device when the session starts.\n",
    "config = tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "# NOTE(review): this session is never handed to Keras explicitly; presumably it is\n",
    "# created only so the GPU option above takes effect for the process - confirm.\n",
    "session = tf.compat.v1.Session(config=config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Time bin of the input data. Options: 64ms, 128ms, 256ms\n",
    "choose_time_bin = \"64ms\"\n",
    "# Model architecture. Options: plain-CNN, ResNet, ResNet-CBAM\n",
    "choose_model = \"plain-CNN\"\n",
    "\n",
    "# Directory passed to data_utils.get_train_val_data. It is left empty here;\n",
    "# presumably it has to be set to the local data set location before running.\n",
    "data_set_dir = r\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# load and pre-process data (train and validate)\n",
    "(train_x, train_y, train_info), (val_x, val_y, val_info) = data_utils.get_train_val_data(\n",
    "    data_set_dir, choose_time_bin\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Select and build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Model builders keyed by \"<model>-<time bin>\".\n",
    "model_func_dic = {\n",
    "    \"plain-CNN-64ms\": model_utils.plain_cnn_64ms,\n",
    "    \"plain-CNN-128ms\": model_utils.plain_cnn_128ms,\n",
    "    \"plain-CNN-256ms\": model_utils.plain_cnn_256ms,\n",
    "\n",
    "    \"ResNet-64ms\": model_utils.resnet_64ms,\n",
    "    \"ResNet-128ms\": model_utils.resnet_128ms,\n",
    "    \"ResNet-256ms\": model_utils.resnet_256ms,\n",
    "\n",
    "    \"ResNet-CBAM-64ms\": model_utils.resnet_CBAM_64ms,\n",
    "    \"ResNet-CBAM-128ms\": model_utils.resnet_CBAM_128ms,\n",
    "    \"ResNet-CBAM-256ms\": model_utils.resnet_CBAM_256ms,\n",
    "}\n",
    "\n",
    "# Fail here with a readable message instead of letting an unknown combination\n",
    "# surface later as \"'NoneType' object is not callable\".\n",
    "model_key = choose_model + \"-\" + choose_time_bin\n",
    "if model_key not in model_func_dic:\n",
    "    raise ValueError(\n",
    "        \"Unsupported model/time-bin combination %r; expected one of: %s\"\n",
    "        % (model_key, \", \".join(sorted(model_func_dic)))\n",
    "    )\n",
    "model_func = model_func_dic[model_key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Reset Keras backend state so re-running this cell builds a fresh model.\n",
    "K.clear_session()\n",
    "\n",
    "# The input shape is train_x's shape without its first axis (presumably the\n",
    "# sample axis - confirm against data_utils); nb_classes is fixed to 2.\n",
    "input_shape, nb_classes = train_x.shape[1:], 2\n",
    "input_layer = keras.layers.Input(shape=input_shape, name='input')\n",
    "\n",
    "model = model_func(input_layer, nb_classes)\n",
    "# NOTE(review): the name is built without a separator (e.g. plain-CNN64ms),\n",
    "# unlike the lookup key above - confirm this is the intended output name.\n",
    "model_name = choose_model + choose_time_bin\n",
    "\n",
    "adam = keras.optimizers.Adam(lr=0.0001, beta_1=0.95, beta_2=0.999, epsilon=1e-08)\n",
    "model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "trainEpochs = 1000\n",
    "trainBatchSize = 32\n",
    "\n",
    "model_utils.train_model(\n",
    "    model, train_x, train_y, val_x, val_y, trainEpochs, trainBatchSize,\n",
    "    modelName=model_name, outDir=\"input/out_gpu/\", binSize=choose_time_bin\n",
    ")\n",
    "print(\"done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}