README.md 955 Bytes
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

Introduction:
            
Keras-CV源码及训练代码示例,TensorFlow版本为2.9,python版本支持3.7、3.8、3.9.
支持的网络有 ResNet101、ResNet50V2、RegNetX064、DarkNet53、DenseNet121、EfficientNetV2S
网络模型代码位于keras-cv/models.训练代码basic_training.py位于examples/training/classification/imagenet/
使用tfrecord格式的Imagenet数据集.相关网络的训练情况及参数设置位于examples/training/classification/imagenet/
rain_history.json

Install

安装DTK版本的TensorFlow2.9 pip install tensorflow-2.9.0+git28e158bd.dtk22042-cp37-cp37m-linux_x86_64.whl
执行python setup.py build
   python setup.py install
       
Train   
训练脚本位于examples/training/classification/imagenet/trian_keras_cv.sh 注意basic_training.py中的分布式设置为
默认使用所有检测到的加速卡设备,若要指定几张加速卡参与训练使用export HIP_VISIBLE_DEVICES=x1,x2,x3来指定.