README.md 5.42 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
# **Keras-cv**
## 1.Introduction:
相关环境配置及代码位置
zhanggzh's avatar
zhanggzh committed
4

zhanggzh's avatar
zhanggzh committed
5
6
```
Keras-CV源码及训练代码示例,DTK版本为22.04。TensorFlow版本为2.9,python版本支持3.7、3.8、3.9.
zhanggzh's avatar
zhanggzh committed
7
8
9
10
支持的网络有 ResNet101、ResNet50V2、RegNetX064、DarkNet53、DenseNet121、EfficientNetV2S
网络模型代码位于keras-cv/models.训练代码basic_training.py位于examples/training/classification/imagenet/
使用tfrecord格式的Imagenet数据集.相关网络的训练情况及参数设置位于examples/training/classification/imagenet/
rain_history.json
zhanggzh's avatar
zhanggzh committed
11
12
```
## 2.Install
zhanggzh's avatar
zhanggzh committed
13

zhanggzh's avatar
zhanggzh committed
14
目前keras-cv模型版本支持最低要求TF2.9+
zhanggzh's avatar
zhanggzh committed
15

zhanggzh's avatar
zhanggzh committed
16
```
zhanggzh's avatar
zhanggzh committed
17
18
19
安装DTK版本的TensorFlow2.9 pip install tensorflow-2.9.0+git28e158bd.dtk22042-cp37-cp37m-linux_x86_64.whl
执行python setup.py build
   python setup.py install
zhanggzh's avatar
zhanggzh committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
```    

## 3.Train 

训练脚本位于keras-cv/examples/training/classification/imagenet/trian_keras_cv.sh 注意basic_training.py中的分布式设置为
默认使用所有检测到的加速卡设备,若要指定几张加速卡参与训练使用export HIP_VISIBLE_DEVICES=x1,x2,x3来指定.

``` 
以乌镇平台为例,
1 数据准备
乌镇平台tfrecord格式的ImageNet数据集位置为/public/DL_DATA/imagenet_tfrecord/trrecord
2 创建虚拟环境
这里以anaconda3创建虚拟环境为例,
conda create -n tf2.9-keras python==3.7.0,  创建名为tf2.9-keras的虚拟环境,python版本为3.7.0
conda info --envs,  查看已经创建的虚拟环境
conda activate tf2.9-keras, 激活虚拟环境
下载tensorflow-2.9.0+git28e158bd.dtk22042-cp37-cp37m-linux_x86_64.whl安装包,pip install ... 安装dtk22.04版本的TensorFlow2.9
 其余所需的三方包与版本为
 Package                      Version
---------------------------- --------------------------
absl-py                      1.3.0
astunparse                   1.6.3
cachetools                   5.2.0
certifi                      2022.9.24
charset-normalizer           2.1.1
cycler                       0.11.0
dill                         0.3.6
etils                        0.9.0
flatbuffers                  1.12
fonttools                    4.38.0
gast                         0.4.0
google-auth                  2.13.0
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
googleapis-common-protos     1.57.0
grpcio                       1.50.0
h5py                         3.7.0
idna                         3.4
importlib-metadata           5.0.0
importlib-resources          5.10.0
keras                        2.9.0
keras-cv                     0.3.4
keras-nlp                    0.4.0
Keras-Preprocessing          1.1.2
kiwisolver                   1.4.4
libclang                     14.0.6
Markdown                     3.4.1
MarkupSafe                   2.1.1
matplotlib                   3.5.3
numpy                        1.21.6
oauthlib                     3.2.2
opt-einsum                   3.3.0
packaging                    21.3
Pillow                       9.3.0
pip                          22.2.2
promise                      2.3
protobuf                     3.19.6
pyasn1                       0.4.8
pyasn1-modules               0.2.8
pycocotools                  2.0.6
pyparsing                    3.0.9
python-dateutil              2.8.2
regex                        2022.9.13
requests                     2.28.1
requests-oauthlib            1.3.1
rsa                          4.9
setuptools                   63.4.1
six                          1.16.0
tensorboard                  2.9.1
tensorboard-data-server      0.6.1
tensorboard-plugin-wit       1.8.1
tensorflow                   2.9.0+git28e158bd.dtk22042
tensorflow-datasets          4.7.0
tensorflow-estimator         2.9.0
tensorflow-hub               0.12.0
tensorflow-io-gcs-filesystem 0.27.0
tensorflow-metadata          1.11.0
tensorflow-text              2.9.0
termcolor                    2.0.1
toml                         0.10.2
tqdm                         4.64.1
typing_extensions            4.4.0
urllib3                      1.26.12
Werkzeug                     2.2.2
wheel                        0.37.1
wrapt                        1.14.1
zipp                         3.10.0

环境配置完毕后,cd cd keras-cv/examples/training/classification/imagenet/, 进入训练脚本与训练代码所在位置
ln -s /public/DL_DATA/imagenet_tfrecord/trrecord imagenet, 创建数据集软链接
训练脚本为train_keras_cv.sh
训练脚本中的参数解释
--model_name :所训练的模型名称,所支持的网络名称位于keras-cv/keras_cv/models/__init__.py,注意目前经过DTK22.04训练验证的网络有ResNet101、ResNet50V2、RegNetX064、DarkNet53、DenseNet121、EfficientNetV2S,其余网络尚未训练验证。
--imagenet_path :数据集路径
--backup_path :模型更新路径
--weights_path :模型存储路径 注意:不要将该路径与--backup_path设置为同一个
--tensorboard_path : tensorboard日志存储路径
--use_xla=False :是否使用xla,当前版本使用xla存在错误,设置为False
--initial_learning_rate : 起始学习率
--learning_rate_schedule :学习率使用策略
--batch_size :batch_size设置
--epochs :训练迭代多少次

注意:进行断点续训时,第一个Epoch之后,训练日志输出的每个Epoch的step数量存在出入,可忽略,因为实际训练中每个Epoch的训练的step数量是不变的。此外已经经过训练验证的相关网络的数据记录在keras/examples/training/classification/imagenet/train_history.json文件中

```