docowl_doclocal4k_evaluate.py 3.07 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import jsonlines
from docowl_infer import DocOwlInfer
from tqdm import tqdm
import os
from icecream import ic
from evaluation.benchmarks_eval import llm_text_localization_eval
import argparse

def read_jsonl(filename):
    """Load every record of a JSON Lines file into a list.

    The file is decoded as UTF-8 and parsed with ``jsonlines.Reader``; the
    records are returned in the order they appear in the file.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return list(jsonlines.Reader(f))


def save_jsonl(data, filename, print_log=True):
    """Write a list of JSON-serialisable records to ``filename`` as JSON Lines.

    Records are separated by newlines (no trailing newline is written) and
    non-ASCII characters are kept verbatim (``ensure_ascii=False``).

    Args:
        data: list of records; each becomes one line of the output file.
        filename: destination path; an existing file is overwritten.
        print_log: when true, print how many samples were saved and where.
    """
    # Explicit UTF-8: with ensure_ascii=False the platform default encoding
    # (e.g. cp1252 on Windows) can fail on non-ASCII text, and read_jsonl
    # always decodes these files as UTF-8.
    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in data))

    if print_log:
        print('save %d samples to %s' % (len(data), filename))


# Metrics passed to llm_text_localization_eval for each supported task.
_TASK_METRICS = {
    'text_grounding': ['IOU@0.5'],
    'text_recognition': ['BLEU1', 'BLEU2', 'BLEU3', 'BLEU4'],
}


def _build_arg_parser():
    """Build the command-line parser for the DocLocal4K evaluation script."""
    parser = argparse.ArgumentParser(description='docowl1.5 doclocal4k evaluation')
    # Optional: only needed when predictions have to be generated; an existing
    # prediction file can be re-scored without loading a model.
    parser.add_argument('--model_path', type=str, help='the directory path of model')
    parser.add_argument('--task', type=str, required=True, choices=sorted(_TASK_METRICS))
    parser.add_argument('--doclocal4k_dir', type=str, required=True, help='the directory path of DocLocal4K')
    parser.add_argument('--save_dir', type=str, required=True, help='the directory to save predictions of the model')
    return parser


def _run_inference(model_path, test_path, doclocal4k_dir, save_path):
    """Run DocOwl on every test sample and save the annotated samples as JSONL.

    Each sample read from ``test_path`` must provide ``image`` (its first entry
    is a path relative to ``doclocal4k_dir``) and ``messages`` whose first two
    entries are a user turn and an assistant turn. The model reply is stored
    under ``model_answer`` and the assistant content under ``gt_answer``.

    Raises:
        FileNotFoundError: if a sample's image file does not exist.
        ValueError: if a sample's first two messages are not user/assistant.
    """
    docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=False)
    print('load model from ', model_path)
    # infer the test samples one by one
    test_samples = read_jsonl(test_path)
    infer_results = []
    for sample in tqdm(test_samples):
        image = os.path.join(doclocal4k_dir, sample['image'][0])
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(image):
            raise FileNotFoundError('image not found: ' + image)
        question = sample['messages'][0]
        answer = sample['messages'][1]
        if question['role'] != 'user' or answer['role'] != 'assistant':
            raise ValueError(
                'expected user/assistant messages, got %r/%r' % (question['role'], answer['role'])
            )
        # The image is passed separately, so drop the placeholder token from the prompt.
        query = question['content'].replace('<|image|>', '')
        gt_answer = answer['content']
        model_answer = docowl.inference(image, query)
        sample['model_answer'] = model_answer
        sample['gt_answer'] = gt_answer
        ic(model_answer, gt_answer)
        infer_results.append(sample)

    save_jsonl(infer_results, save_path)


def _evaluate(task, pred_path):
    """Score the prediction file ``pred_path`` with the metrics for ``task``.

    Raises:
        SystemExit: (non-zero status) if ``pred_path`` does not exist.
    """
    if not os.path.exists(pred_path):
        raise SystemExit('not exists: ' + pred_path)
    # Pass a fresh list so the shared table cannot be mutated by the evaluator.
    llm_text_localization_eval(
        metric_names=list(_TASK_METRICS[task]), result_path=pred_path, save_each_eval=True
    )
    print('==============================================')


def main():
    """Entry point: generate DocLocal4K predictions if missing, then compute metrics."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    test_path = os.path.join(args.doclocal4k_dir, args.task + '.jsonl')
    save_path = os.path.join(args.save_dir, args.task + '_test_pred.jsonl')

    # NOTE: an existing prediction file is reused as-is, even if it was
    # produced by a different --model_path; delete it to force re-inference.
    if os.path.exists(save_path):
        print(save_path + ' exists, skip inference. ')
    else:
        if args.model_path is None:
            parser.error('--model_path is required to run inference (no predictions at %s)' % save_path)
        _run_inference(args.model_path, test_path, args.doclocal4k_dir, save_path)

    # calculate metrics
    _evaluate(args.task, save_path)


if __name__ == '__main__':
    main()