"examples/trials/vscode:/vscode.git/clone" did not exist on "b794076ecb38ce7b2aba9efac009c8e868d1b5cd"
Commit c0f5652d authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'main' into 'main'

Main

See merge request dcutoolkit/deeplearing/dlexamples_new!60
parents d3910de2 500760f3
...@@ -32,7 +32,7 @@ train_percent = 0.9 ...@@ -32,7 +32,7 @@ train_percent = 0.9
#-------------------------------------------------------# #-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit' VOCdevkit_path = 'VOCdevkit'
VOCdevkit_sets = [('2012', 'train'), ('2012', 'val')] VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
classes, _ = get_classes(classes_path) classes, _ = get_classes(classes_path)
def convert_annotation(year, image_id, list_file): def convert_annotation(year, image_id, list_file):
...@@ -56,8 +56,8 @@ if __name__ == "__main__": ...@@ -56,8 +56,8 @@ if __name__ == "__main__":
random.seed(0) random.seed(0)
if annotation_mode == 0 or annotation_mode == 1: if annotation_mode == 0 or annotation_mode == 1:
print("Generate txt in ImageSets.") print("Generate txt in ImageSets.")
xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2012/Annotations') xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2012/ImageSets/Main') saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
temp_xml = os.listdir(xmlfilepath) temp_xml = os.listdir(xmlfilepath)
total_xml = [] total_xml = []
for xml in temp_xml: for xml in temp_xml:
...@@ -106,4 +106,4 @@ if __name__ == "__main__": ...@@ -106,4 +106,4 @@ if __name__ == "__main__":
convert_annotation(year, image_id, list_file) convert_annotation(year, image_id, list_file)
list_file.write('\n') list_file.write('\n')
list_file.close() list_file.close()
print("Generate 2012_train.txt and 2012_val.txt for train done.") print("Generate 2007_train.txt and 2007_val.txt for train done.")
...@@ -94,7 +94,7 @@ python tools/visualize_dataset.py --classes=./data/voc2012.names ...@@ -94,7 +94,7 @@ python tools/visualize_dataset.py --classes=./data/voc2012.names
``` ```
export HIP_VISIBLE_DEVICES=0 export HIP_VISIBLE_DEVICES=0
export PYTHONPATH=/public/home/zhenyi/miniconda3/envs/tf2.7.0-dtk21.10-build/bin/ export PYTHONPATH=/public/home/zhenyi/miniconda3/envs/tf2.7.0-dtk21.10-build/bin/
#without xla
python train.py \ python train.py \
--dataset ./data/voc2012_train.tfrecord \ --dataset ./data/voc2012_train.tfrecord \
--val_dataset ./data/voc2012_val.tfrecord \ --val_dataset ./data/voc2012_val.tfrecord \
...@@ -105,5 +105,16 @@ python train.py \ ...@@ -105,5 +105,16 @@ python train.py \
--epochs 10 \ --epochs 10 \
--weights ./checkpoints/yolov3.tf \ --weights ./checkpoints/yolov3.tf \
--weights_num_classes 80 --weights_num_classes 80
#with xla
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python train.py \
--dataset ./data/voc2012_train.tfrecord \
--val_dataset ./data/voc2012_val.tfrecord \
--classes ./data/voc2012.names \
--num_classes 20 \
--mode fit --transfer darknet \
--batch_size 16 \
--epochs 10 \
--weights ./checkpoints/yolov3.tf \
--weights_num_classes 80
``` ```
...@@ -17,8 +17,16 @@ ...@@ -17,8 +17,16 @@
- `export MIOPEN_FIND_MODE=3` - `export MIOPEN_FIND_MODE=3`
## 测试运行 ## 测试运行
#whitout xla
` python3 main.py ` ` python3 main.py `
#with xla
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 main.py
## 参数说明 ## 参数说明
main.py文件内: main.py文件内:
- trainGenerator的第一个参数为batch_size - trainGenerator的第一个参数为batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment