# **Keras-cv** ## 1.Introduction: 相关环境配置及代码位置 ``` Keras-CV源码及训练代码示例,DTK版本为22.04。TensorFlow版本为2.9,python版本支持3.7、3.8、3.9. 支持的网络有 ResNet101、ResNet50V2、RegNetX064、DarkNet53、DenseNet121、EfficientNetV2S 网络模型代码位于keras-cv/models.训练代码basic_training.py位于examples/training/classification/imagenet/ 使用tfrecord格式的Imagenet数据集.相关网络的训练情况及参数设置位于examples/training/classification/imagenet/ rain_history.json ``` ## 2.Install 目前keras-cv模型版本支持最低要求TF2.9+ ``` 安装DTK版本的TensorFlow2.9 pip install tensorflow-2.9.0+git28e158bd.dtk22042-cp37-cp37m-linux_x86_64.whl 执行python setup.py build python setup.py install ``` ## 3.Train 训练脚本位于keras-cv/examples/training/classification/imagenet/trian_keras_cv.sh 注意basic_training.py中的分布式设置为 默认使用所有检测到的加速卡设备,若要指定几张加速卡参与训练使用export HIP_VISIBLE_DEVICES=x1,x2,x3来指定. ``` 以乌镇平台为例, 1 数据准备 乌镇平台tfrecord格式的ImageNet数据集位置为/public/DL_DATA/imagenet_tfrecord/trrecord 2 创建虚拟环境 这里以anaconda3创建虚拟环境为例, conda create -n tf2.9-keras python==3.7.0, 创建名为tf2.9-keras的虚拟环境,python版本为3.7.0 conda info --envs, 查看已经创建的虚拟环境 conda activate tf2.9-keras, 激活虚拟环境 下载tensorflow-2.9.0+git28e158bd.dtk22042-cp37-cp37m-linux_x86_64.whl安装包,pip install ... 安装dtk22.04版本的TensorFlow2.9 其余所需的三方包与版本为 Package Version ---------------------------- -------------------------- absl-py 1.3.0 astunparse 1.6.3 cachetools 5.2.0 certifi 2022.9.24 charset-normalizer 2.1.1 cycler 0.11.0 dill 0.3.6 etils 0.9.0 flatbuffers 1.12 fonttools 4.38.0 gast 0.4.0 google-auth 2.13.0 google-auth-oauthlib 0.4.6 google-pasta 0.2.0 googleapis-common-protos 1.57.0 grpcio 1.50.0 h5py 3.7.0 idna 3.4 importlib-metadata 5.0.0 importlib-resources 5.10.0 keras 2.9.0 keras-cv 0.3.4 keras-nlp 0.4.0 Keras-Preprocessing 1.1.2 kiwisolver 1.4.4 libclang 14.0.6 Markdown 3.4.1 MarkupSafe 2.1.1 matplotlib 3.5.3 numpy 1.21.6 oauthlib 3.2.2 opt-einsum 3.3.0 packaging 21.3 Pillow 9.3.0 pip 22.2.2 promise 2.3 protobuf 3.19.6 pyasn1 0.4.8 pyasn1-modules 0.2.8 pycocotools 2.0.6 pyparsing 3.0.9 python-dateutil 2.8.2 regex 2022.9.13 requests 2.28.1 requests-oauthlib 1.3.1 rsa 4.9 setuptools 63.4.1 six 1.16.0 tensorboard 2.9.1 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.1 tensorflow 2.9.0+git28e158bd.dtk22042 tensorflow-datasets 4.7.0 tensorflow-estimator 2.9.0 tensorflow-hub 0.12.0 tensorflow-io-gcs-filesystem 0.27.0 tensorflow-metadata 1.11.0 tensorflow-text 2.9.0 termcolor 2.0.1 toml 0.10.2 tqdm 4.64.1 typing_extensions 4.4.0 urllib3 1.26.12 Werkzeug 2.2.2 wheel 0.37.1 wrapt 1.14.1 zipp 3.10.0 环境配置完毕后,cd cd keras-cv/examples/training/classification/imagenet/, 进入训练脚本与训练代码所在位置 ln -s /public/DL_DATA/imagenet_tfrecord/trrecord imagenet, 创建数据集软链接 训练脚本为train_keras_cv.sh 训练脚本中的参数解释 --model_name :所训练的模型名称,所支持的网络名称位于keras-cv/keras_cv/models/__init__.py,注意目前经过DTK22.04训练验证的网络有ResNet101、ResNet50V2、RegNetX064、DarkNet53、DenseNet121、EfficientNetV2S,其余网络尚未训练验证。 --imagenet_path :数据集路径 --backup_path :模型更新路径 --weights_path :模型存储路径 注意:不要将该路径与--backup_path设置为同一个 --tensorboard_path : tensorboard日志存储路径 --use_xla=False :是否使用xla,当前版本使用xla存在错误,设置为False --initial_learning_rate : 起始学习率 --learning_rate_schedule :学习率使用策略 --batch_size :batch_size设置 --epochs :训练迭代多少次 注意:进行断点续训时,第一个Epoch之后,训练日志输出的每个Epoch的step数量存在出入,可忽略,因为实际训练中每个Epoch的训练的step数量是不变的。此外已经经过训练验证的相关网络的数据记录在keras/examples/training/classification/imagenet/train_history.json文件中 ```