# gen_ocr_train_val_test.py
# coding:utf8
import os
import shutil
import random
import argparse

# Delete an existing train/val/test split folder and recreate it as an empty folder.
def isCreateOrDeleteFolder(path, flag):
    """Create an empty sub-folder named ``flag`` under ``path``.

    If the sub-folder already exists it is removed together with everything
    inside it, so the caller always starts from an empty directory.

    Args:
        path: Parent directory of the split folder.
        flag: Name of the sub-folder to (re)create; callers in this file pass
            "train", "val" or "test".

    Returns:
        The absolute path of the freshly created folder.
    """
    flagPath = os.path.join(path, flag)

    # Start from a clean slate: drop results of any previous run.
    if os.path.exists(flagPath):
        shutil.rmtree(flagPath)

    os.makedirs(flagPath)
    flagAbsPath = os.path.abspath(flagPath)
    return flagAbsPath


def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
    """Split one annotated folder into training, validation and test subsets.

    The label file of the folder ``root/dir`` is read, its records are
    shuffled, and each record is assigned to a subset according to its
    position in the shuffled list and the configured ratio. The referenced
    image is copied into the subset folder and a ``"<copied path>\\t<label>"``
    line is written to the matching list file.

    Configuration is read from the module-level ``args`` namespace
    (``detLabelFileName``, ``recLabelFileName``, ``recImageDirName`` and
    ``trainValTestRatio``), not from a parameter.

    Args:
        root: Directory that contains the annotated folder.
        dir: Name of the annotated folder inside ``root``.
        absTrainRootPath: Destination folder for training images.
        absValRootPath: Destination folder for validation images.
        absTestRootPath: Destination folder for test images.
        trainTxt: Writable text file receiving the training records.
        valTxt: Writable text file receiving the validation records.
        testTxt: Writable text file receiving the test records.
        flag: ``"det"`` for detection data or ``"rec"`` for recognition data.

    Raises:
        ValueError: If ``flag`` is neither ``"det"`` nor ``"rec"``, or a ratio
            component is not a number.
        IndexError: If a non-blank record has no tab-separated label.
    """
    labelAbsPath = os.path.abspath(os.path.join(root, dir))

    if flag == "det":
        labelFileName = args.detLabelFileName
    elif flag == "rec":
        labelFileName = args.recLabelFileName
    else:
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))
    labelFilePath = os.path.join(labelAbsPath, labelFileName)

    # Close the label file as soon as it is read; blank lines carry no record.
    with open(labelFilePath, "r", encoding="UTF-8") as labelFileRead:
        labelFileContent = [line for line in labelFileRead if line.strip()]
    random.shuffle(labelFileContent)
    labelRecordLen = len(labelFileContent)

    # The ratio is expressed in tenths (e.g. "6:2:2"); it is loop-invariant,
    # so parse it once. float() replaces the former eval() on a CLI string.
    trainValTestRatio = args.trainValTestRatio.split(":")
    trainRatio = float(trainValTestRatio[0]) / 10
    valRatio = trainRatio + float(trainValTestRatio[1]) / 10

    for index, labelRecordInfo in enumerate(labelFileContent):
        # Split on the first tab only so that a label containing tabs is kept
        # whole; strip the newline so every written record ends with exactly one.
        recordFields = labelRecordInfo.rstrip("\n").split("\t", 1)
        imageRelativePath = recordFields[0]
        imageLabel = recordFields[1]
        imageName = os.path.basename(imageRelativePath)

        if flag == "det":
            imagePath = os.path.join(labelAbsPath, imageName)
        else:
            # Join the components with os.path.join so the path is valid on
            # every platform, not only where "\\" is the separator.
            imagePath = os.path.join(labelAbsPath, args.recImageDirName, imageName)

        # Assign the record to a subset by its position in the shuffled list.
        curRatio = index / labelRecordLen
        if curRatio < trainRatio:
            targetRootPath, targetTxt = absTrainRootPath, trainTxt
        elif curRatio < valRatio:
            targetRootPath, targetTxt = absValRootPath, valTxt
        else:
            targetRootPath, targetTxt = absTestRootPath, testTxt

        # NOTE(review): images are stored by base name only, so two source
        # folders containing the same file name would overwrite each other.
        imageCopyPath = os.path.join(targetRootPath, imageName)
        shutil.copy(imagePath, imageCopyPath)
        targetTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))


# Remove a file if it already exists.
def removeFile(path):
    """Delete the file at ``path``; do nothing when the path does not exist."""
    if not os.path.exists(path):
        return
    os.remove(path)


def genDetRecTrainVal(args):
    """Build the detection and recognition train/val/test datasets.

    The ``train``, ``val`` and ``test`` folders under ``args.detRootPath`` and
    ``args.recRootPath`` are recreated empty, the ``train.txt``, ``val.txt``
    and ``test.txt`` list files are removed and reopened, and every immediate
    sub-folder of ``args.labelRootPath`` is split into both datasets.

    NOTE: ``splitTrainVal`` reads its configuration from the module-level
    ``args`` object rather than from this function's parameter, so the two
    must refer to the same namespace.

    Args:
        args: Namespace providing ``detRootPath``, ``recRootPath`` and
            ``labelRootPath``.
    """
    detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
    detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
    detAbsTestRootPath = isCreateOrDeleteFolder(args.detRootPath, "test")
    recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
    recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
    recAbsTestRootPath = isCreateOrDeleteFolder(args.recRootPath, "test")

    # Drop list files left over from a previous run before appending to them.
    for datasetRootPath in (args.detRootPath, args.recRootPath):
        for listFileName in ("train.txt", "val.txt", "test.txt"):
            removeFile(os.path.join(datasetRootPath, listFileName))

    # The context manager guarantees that all six list files are flushed and
    # closed, even if splitting one of the folders fails.
    with open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8") as detTrainTxt, \
            open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8") as detValTxt, \
            open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8") as detTestTxt, \
            open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8") as recTrainTxt, \
            open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8") as recValTxt, \
            open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8") as recTestTxt:
        for root, dirs, _ in os.walk(args.labelRootPath):
            for labelDirName in dirs:
                splitTrainVal(root, labelDirName, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath,
                              detTrainTxt, detValTxt, detTestTxt, "det")
                splitTrainVal(root, labelDirName, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath,
                              recTrainTxt, recValTxt, recTestTxt, "rec")
            # Only the immediate sub-folders of the label root are processed.
            break


if __name__ == "__main__":
    # Purpose: split the detection and the recognition data into training, validation and test sets.
    # Note: adjust the parameters to your own paths and needs. Image data is often annotated in batches by
    # several people, each batch being a separate folder labelled with PPOCRLabel, so several annotated
    # image folders have to be merged and split into training, validation and test sets.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trainValTestRatio",
        type=str,
        default="6:2:2",
        help="ratio of trainset:valset:testset")
    parser.add_argument(
        "--labelRootPath",
        type=str,
        default="../train_data/label",
        help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
    )
    parser.add_argument(
        "--detRootPath",
        type=str,
        default="../train_data/det",
        help="the path where the divided detection dataset is placed")
    parser.add_argument(
        "--recRootPath",
        type=str,
        default="../train_data/rec",
        help="the path where the divided recognition dataset is placed"
    )
    parser.add_argument(
        "--detLabelFileName",
        type=str,
        default="Label.txt",
        help="the name of the detection annotation file")
    parser.add_argument(
        "--recLabelFileName",
        type=str,
        default="rec_gt.txt",
        help="the name of the recognition annotation file"
    )
    parser.add_argument(
        "--recImageDirName",
        type=str,
        default="crop_img",
        help="the name of the folder where the cropped recognition dataset is located"
    )
    # ``args`` must stay a module-level name: splitTrainVal reads it as a global.
    args = parser.parse_args()
    genDetRecTrainVal(args)