"vscode:/vscode.git/clone" did not exist on "6d2d0ce2857a51e6136ae78205bd5848f26e5813"
create_dataset.py 3.08 KB
Newer Older
dengjb's avatar
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import lmdb  # install via "pip install lmdb"
import cv2
import numpy as np


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    # cv2.imdecode returns None (rather than raising) on undecodable bytes.
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k.encode(), v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create an LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image
    """
    assert len(imagePathList) == len(labelList)
    if not os.path.exists(outputPath):
        os.makedirs(outputPath)
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)  # map_size = 2**40 bytes (1 TiB) upper bound
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
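
# For reference, a minimal sketch of how a consumer might read samples back,
# assuming the key scheme written above ('num-samples', 'image-%09d',
# 'label-%09d', with counting starting at 1). The name readSample is
# hypothetical and not part of the original script.
def readSample(lmdbPath, index):
    env = lmdb.open(lmdbPath, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))
        assert 1 <= index <= nSamples
        imageBin = txn.get(('image-%09d' % index).encode())
        label = txn.get(('label-%09d' % index).encode()).decode()
    # Decode the raw image bytes the same way checkImageIsValid does.
    img = cv2.imdecode(np.frombuffer(imageBin, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    return img, label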

def parse_labels(path):
    """Parse a Synth90k annotation file into (labels, image paths)."""
    labels, image_path = [], []
    with open(path, 'r') as f:
        for line in f:
            image, label = line.strip("\n").split(' ')
            labels.append(label)
            # Annotation paths start with "./"; re-root them under the dataset dir.
            image = "./90kDICT32px/" + image[1:]
            image_path.append(image)
    return labels, image_path
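
# For illustration only: a hypothetical annotation line and what parse_labels
# yields for it (the exact Synth90k line content is an assumption here):
#   "./2697/6/466_MONIKER_49537.jpg 49537"
#   -> label:      "49537"
#   -> image path: "./90kDICT32px/2697/6/466_MONIKER_49537.jpg"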

if __name__ == '__main__':
    output_path = "./synth90k"
    dataset_path = "90kDICT32px/"
    train_labels, train_images = parse_labels(dataset_path + "annotation_train.txt")
    val_labels,val_images = paese_labels(dataset_path + "annotation_val.txt")
    createDataset(output_path + '/train', train_images, train_labels)
    createDataset(output_path + '/val', val_images, val_labels)