# gen_ocr_train_val_test.py
# coding:utf8
import os
import shutil
import random
import argparse

def isCreateOrDeleteFolder(path, flag):
    """Recreate the split sub-folder ``flag`` of ``path`` as an empty directory.

    An existing ``path/flag`` is deleted together with everything inside it, so files
    from a previous split cannot survive into the new one.

    Args:
        path: parent directory; ``os.makedirs`` also creates it if it is missing.
        flag: name of the sub-folder, e.g. ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        The absolute path of the freshly created, empty folder.
    """
    folder = os.path.join(path, flag)
    # Wipe a split left over from an earlier run before creating the folder again.
    if os.path.exists(folder):
        shutil.rmtree(folder)
    os.makedirs(folder)
    return os.path.abspath(folder)


def _parseTrainValTestRatio(ratioText):
    """Turn a ``"train:val:test"`` ratio string into exact cumulative split bounds.

    The three parts may be integers or decimals (``"6:2:2"``, ``"3:1:1"``, ``"0.7:0.2:0.1"``).
    They are normalized by their sum, so they do not have to add up to 10.

    Returns:
        ``(trainBound, valBound)`` as ``fractions.Fraction`` values in ``[0, 1]``.

    Raises:
        ValueError: unless the string holds exactly three non-negative numbers with a
            positive sum.
    """
    from fractions import Fraction  # exact arithmetic: no float rounding at the split boundaries

    parts = ratioText.split(":")
    if len(parts) != 3:
        raise ValueError("trainValTestRatio must look like 'train:val:test', got {!r}".format(ratioText))
    try:
        # Parse numbers instead of eval()-ing text that comes from the command line.
        trainPart, valPart, testPart = [Fraction(part.strip()) for part in parts]
    except (ValueError, ZeroDivisionError) as err:
        raise ValueError("trainValTestRatio must contain numbers, got {!r}".format(ratioText)) from err
    total = trainPart + valPart + testPart
    if min(trainPart, valPart, testPart) < 0 or total == 0:
        raise ValueError(
            "trainValTestRatio needs non-negative parts with a positive sum, got {!r}".format(ratioText))
    return trainPart / total, (trainPart + valPart) / total


def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag,
                  labelFileName=None, recImageDirName=None, trainValTestRatio=None):
    """Shuffle one annotation file and divide its images into train / val / test.

    Each non-blank line of the label file is ``<image path><TAB><label>``. The image is looked
    up by its base name. For ``"det"`` it is taken directly from ``root``; for ``"rec"`` it is
    taken from ``root/<recImageDirName>``. The image is copied into the chosen split folder,
    and ``<absolute copy path><TAB><label>`` plus a newline is written to that split's stream.

    Args:
        root: directory containing the label file (and, for ``"det"``, the images).
        absTrainRootPath, absValRootPath, absTestRootPath: destination folders.
        trainTxt, valTxt, testTxt: writable text streams receiving the split lists.
        flag: ``"det"`` or ``"rec"``.
        labelFileName: label file name inside ``root``. Defaults to ``args.detLabelFileName`` or
            ``args.recLabelFileName`` from the module-level ``args``.
        recImageDirName: sub-folder of ``root`` holding the ``"rec"`` images. Defaults to
            ``args.recImageDirName``.
        trainValTestRatio: ``"train:val:test"`` ratio. Defaults to ``args.trainValTestRatio``.

    Raises:
        ValueError: for an unknown ``flag``, a malformed ratio, or a non-blank label line
            without a tab separator.
    """
    if flag not in ("det", "rec"):
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    # Settings not passed explicitly fall back to the command-line options that the
    # ``__main__`` section stores in the module-level ``args`` (the original behavior).
    if trainValTestRatio is None:
        trainValTestRatio = args.trainValTestRatio
    if labelFileName is None:
        labelFileName = args.detLabelFileName if flag == "det" else args.recLabelFileName
    if flag == "rec" and recImageDirName is None:
        recImageDirName = args.recImageDirName

    # Parse (and validate) the ratio once, before any file is touched.
    trainBound, valBound = _parseTrainValTestRatio(trainValTestRatio)

    dataAbsPath = os.path.abspath(root)
    # os.path.join instead of a hard-coded "\\" so the rec path also works on Linux/macOS.
    imageDirPath = dataAbsPath if flag == "det" else os.path.join(dataAbsPath, recImageDirName)

    labelFilePath = os.path.join(dataAbsPath, labelFileName)
    with open(labelFilePath, "r", encoding="UTF-8") as labelFile:
        # Drop line endings and skip blank lines; blank lines would otherwise crash the
        # split and count towards the ratio.
        labelRecords = [line.rstrip("\r\n") for line in labelFile if line.strip()]
    random.shuffle(labelRecords)
    labelRecordLen = len(labelRecords)

    # A record at position index goes to train while index/len < trainBound, then to val
    # while index/len < valBound, then to test. Comparisons are exact (Fraction vs int).
    trainCut = trainBound * labelRecordLen
    valCut = valBound * labelRecordLen

    for index, labelRecordInfo in enumerate(labelRecords):
        imageRelativePath, separator, imageLabel = labelRecordInfo.partition("\t")
        if not separator:
            raise ValueError("malformed line in {} (expected '<image path>\\t<label>'): {!r}".format(
                labelFilePath, labelRecordInfo))
        imageName = os.path.basename(imageRelativePath)
        imagePath = os.path.join(imageDirPath, imageName)

        if index < trainCut:
            destRootPath, destTxt = absTrainRootPath, trainTxt
        elif index < valCut:
            destRootPath, destTxt = absValRootPath, valTxt
        else:
            destRootPath, destTxt = absTestRootPath, testTxt

        imageCopyPath = os.path.join(destRootPath, imageName)
        shutil.copy(imagePath, imageCopyPath)
        # Always terminate the record. The file's last line may lack a newline and, once
        # shuffled into the middle, would otherwise be glued to the next record.
        destTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))


def removeFile(path):
    """Delete ``path`` if something exists there; silently do nothing otherwise."""
    if not os.path.exists(path):
        return
    os.remove(path)


def genDetRecTrainVal(args):
    """Rebuild the detection and recognition train / val / test splits from scratch.

    Steps:
    - Recreate empty ``train``, ``val`` and ``test`` folders under ``args.detRootPath`` and
      ``args.recRootPath``.
    - Rewrite the matching ``train.txt``, ``val.txt`` and ``test.txt`` lists.
    - Split the detection annotations in ``args.datasetRootPath``.
    - If that folder directly contains an ``args.recImageDirName`` sub-folder, split the
      recognition annotations as well.

    Args:
        args: the parsed command-line namespace built in the ``__main__`` section.
    """
    detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
    detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
    detAbsTestRootPath = isCreateOrDeleteFolder(args.detRootPath, "test")

    recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
    recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
    recAbsTestRootPath = isCreateOrDeleteFolder(args.recRootPath, "test")

    def openListFile(rootPath, splitName):
        # Mode "w" truncates a list left over from a previous run. This has the same effect
        # as the old delete-then-append sequence.
        return open(os.path.join(rootPath, splitName + ".txt"), "w", encoding="UTF-8")

    # The ``with`` guarantees all six list files are flushed and closed, even if a split
    # fails half-way; previously they were never closed explicitly.
    with openListFile(args.detRootPath, "train") as detTrainTxt, \
            openListFile(args.detRootPath, "val") as detValTxt, \
            openListFile(args.detRootPath, "test") as detTestTxt, \
            openListFile(args.recRootPath, "train") as recTrainTxt, \
            openListFile(args.recRootPath, "val") as recValTxt, \
            openListFile(args.recRootPath, "test") as recTestTxt:
        splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath,
                      detTrainTxt, detValTxt, detTestTxt, "det")

        # Only the top level of the dataset root is inspected; the old os.walk loop broke
        # after its first iteration. Check the configured folder name rather than a
        # hard-coded 'crop_img', so --recImageDirName is honored here too.
        if os.path.isdir(os.path.join(args.datasetRootPath, args.recImageDirName)):
            splitTrainVal(args.datasetRootPath, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath,
                          recTrainTxt, recValTxt, recTestTxt, "rec")


if __name__ == "__main__":
    # Purpose: split both the detection data and the recognition data into train / val /
    # test sets.
    # Notes: adjust the options below to your own paths and needs. Images are often labelled
    # in batches by several people, each batch in its own folder annotated with PPOCRLabel,
    # so the labelled folders need to be merged and then split into train / val / test sets.
    parser = argparse.ArgumentParser()
    # (flag, default, help) for every option; all options are plain strings.
    for optionFlag, optionDefault, optionHelp in (
            ("--trainValTestRatio", "6:2:2", "ratio of trainset:valset:testset"),
            ("--datasetRootPath", "../train_data/",
             "path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."),
            ("--detRootPath", "../train_data/det",
             "the path where the divided detection dataset is placed"),
            ("--recRootPath", "../train_data/rec",
             "the path where the divided recognition dataset is placed"),
            ("--detLabelFileName", "Label.txt", "the name of the detection annotation file"),
            ("--recLabelFileName", "rec_gt.txt", "the name of the recognition annotation file"),
            ("--recImageDirName", "crop_img",
             "the name of the folder where the cropped recognition dataset is located"),
    ):
        parser.add_argument(optionFlag, type=str, default=optionDefault, help=optionHelp)
    # ``args`` must remain a module-level name: splitTrainVal reads its options from it.
    args = parser.parse_args()
    genDetRecTrainVal(args)