# dbnet_model_eval_new.py
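"""Evaluate pre-computed DBNet inference outputs with PaddleOCR's eval pipeline.

Rather than running a Paddle model, this script loads detection probability
maps that were dumped to disk beforehand (one ``.bin`` or ``.npy`` file per
image, e.g. by an offline inference run), feeds them through the configured
PostProcess and Metric modules, and reports detection metrics on the
ICDAR2015 test split.
"""
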
from glob import glob
import os.path as osp
import platform

import numpy as np
import paddle
from tqdm import tqdm

from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
import tools.program as program


def model_eval(config, logger):
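    # Point the Eval dataset at the ICDAR2015 test split under the user-supplied data_dir.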
    config['Eval']['dataset']['data_dir'] = config['data_dir']
    label_file = osp.join(config['data_dir'], 'test_icdar2015_label.txt')
    config['Eval']['dataset']['label_file_list'] = [str(label_file)]

    output_dir = config['output_dir']
    output_format = config.get('output_format', 'bin')
    valid_formats = ['bin', 'BIN', 'npy', 'NPY']
    if output_format not in valid_formats:
        raise ValueError(
            f'Invalid file format of inference outputs: {output_format}, '
            f'expected one of {"/".join(valid_formats)}')
    output_format = output_format.lower()

    valid_dataloader = build_dataloader(config, 'Eval', None, logger)
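    # Skip the last batch on Windows, mirroring the workaround used in PaddleOCR's own eval loop.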
    max_iter = len(valid_dataloader) - 1 if platform.system() == "Windows" \
        else len(valid_dataloader)

    post_process_class = build_post_process(config['PostProcess'], config['Global'])
    eval_class = build_metric(config['Metric'])
    # output_list = glob(
    #     osp.join(output_dir, '**', f'img_*_0.{output_format}'), recursive=True)
    # if not output_list:
    #     raise RuntimeError('Cannot find inference output files.')

    # output_map = [None] * max_iter
    # for output_file in output_list:
    #     i = int(output_file.rsplit('_', 1)[-2].split('_', 1)[1])
    #     output_map[i] = output_file
    
    pbar = tqdm(total=len(valid_dataloader), desc='eval model', leave=True)
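    # Iterate over the eval set, pairing each image with its pre-computed output map.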
    for i, (batch, img_files) in enumerate(valid_dataloader):
        if i >= max_iter:
            break
        batch_numpy = [item.numpy() for item in batch]

        # One dumped output map per image, named "<image_basename>_0.<output_format>".
        output_file = osp.join(
            output_dir,
            osp.splitext(osp.basename(img_files[0]))[0] + f'_0.{output_format}')
        # output_file = output_map[i]
        if output_format == 'bin':
            # output = np.fromfile(output_file, dtype=np.float16)
            # output = output.astype(np.float32).reshape(1, 1, 736, 1280)
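            # Raw .bin dumps are assumed to hold a float32 map of shape (1, 1, 736, 1280),
            # matching the eval resize used when the outputs were generated; adjust the
            # dtype/shape (see the commented float16 variant above) if the dumps differ.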
            output = np.fromfile(output_file, dtype=np.float32).reshape(1, 1, 736, 1280)
        else:
            output = np.load(output_file).astype(np.float32)
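        # Wrap the map as a Paddle tensor and reuse the standard post-process / metric path.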
        preds = {'maps': paddle.to_tensor(output)}
        post_result = post_process_class(preds, batch_numpy[1])
        eval_class(post_result, batch_numpy)
        pbar.update(1)

    logger.info('↓↓↓↓↓↓↓↓↓↓↓ Metrics ↓↓↓↓↓↓↓↓↓↓↓')
    metric = eval_class.get_metric()
    for k, v in metric.items():
        logger.info('{} = {}'.format(k, v))


def main():
    config, _, logger, _ = program.preprocess()
    model_eval(config, logger)


if __name__ == '__main__':
    """
    Example for bash shell:

    >>> python3 model_eval.py \
    >>> -c PaddleOCR/configs/det/det_mv3_db.yml \
    >>> -o Global.use_gpu=False \
    >>>    data_dir=./icdar2015/text_localization/ \
    >>>    output_dir=./om_outputs/ \
    >>>    output_format=NPY \
    """
    main()