eval_table.py 2.21 KB
Newer Older
WenmuZhou's avatar
WenmuZhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
WenmuZhou's avatar
WenmuZhou committed
18
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
WenmuZhou's avatar
WenmuZhou committed
19
20
21
22
23

import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
WenmuZhou's avatar
WenmuZhou committed
24
from ppstructure.table.predict_table import TableSystem
WenmuZhou's avatar
WenmuZhou committed
25
from ppstructure.utility import init_args
WenmuZhou's avatar
WenmuZhou committed
26
27


WenmuZhou's avatar
WenmuZhou committed
28
29
30
31
32
def parse_args():
    """Parse the command line for the table evaluation script.

    Takes the parser returned by ``init_args()``, registers one extra
    string option ``--gt_path`` on it, and returns the parsed arguments.
    """
    arg_parser = init_args()
    arg_parser.add_argument("--gt_path", type=str)
    parsed = arg_parser.parse_args()
    return parsed

WenmuZhou's avatar
WenmuZhou committed
33
34
35
36
37
38
39
40
def main(gt_path, img_root, args):
    """Run table recognition over a labelled set and print the mean TEDS.

    Args:
        gt_path: Path to a JSON file whose keys are image names and whose
            values unpack into four items: ``(structures, bboxes, contents,
            contents_with_block)``. Only the first and last are used here.
        img_root: Directory that each image name is joined onto.
        args: Parsed command-line arguments, forwarded to ``TableSystem``.

    Raises:
        ValueError: If the ground-truth file contains no entries, so no
            average score can be computed.
        OSError: If OpenCV cannot read one of the listed images.
    """
    teds = TEDS(n_jobs=16)

    text_sys = TableSystem(args)
    # Use a context manager so the ground-truth file handle is closed
    # promptly instead of being left to the garbage collector.
    with open(gt_path) as gt_file:
        jsons_gt = json.load(gt_file)  # gt
    if not jsons_gt:
        # Fail with a clear message instead of a ZeroDivisionError at the
        # averaging step below.
        raise ValueError(
            'no ground-truth entries found in {}'.format(gt_path))

    pred_htmls = []
    gt_htmls = []
    for img_name in tqdm(jsons_gt):
        # read image
        img_path = os.path.join(img_root, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning
            # None rather than raising; stop here with the offending path.
            raise OSError('failed to read image: {}'.format(img_path))
        pred_html = text_sys(img)
        pred_htmls.append(pred_html)

        # Bounding boxes and raw contents are not needed for TEDS.
        gt_structures, _, _, contents_with_block = jsons_gt[img_name]
        gt_html, _ = get_gt_html(gt_structures, contents_with_block)
        gt_htmls.append(gt_html)
    scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
    print('teds:', sum(scores) / len(scores))


def get_gt_html(gt_structures, contents_with_block):
    """Merge structure tokens and per-cell contents into one HTML string.

    Walks ``gt_structures`` in order. Every token containing ``'</td>'``
    consumes the next entry of ``contents_with_block`` and that entry's
    items are emitted immediately before the token; all other tokens are
    emitted unchanged.

    Returns:
        A tuple ``(html, tokens)`` where ``tokens`` is the merged token list
        and ``html`` is the concatenation of those tokens.
    """
    pieces = []
    cell_no = 0
    for token in gt_structures:
        if '</td>' in token:
            # Extending with an empty cell is a no-op, so no emptiness
            # check is needed before inserting the cell's contents.
            pieces.extend(contents_with_block[cell_no])
            cell_no += 1
        pieces.append(token)
    html = ''.join(pieces)
    return html, pieces


if __name__ == '__main__':
    # Script entry point: parse the CLI once, then evaluate using the
    # ground-truth path and image directory taken from the parsed options.
    args = parse_args()
    main(gt_path=args.gt_path, img_root=args.image_dir, args=args)