docowl_benchmark_evaluate.py 3.9 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import json
import jsonlines
from docowl_infer import DocOwlInfer
from tqdm import tqdm
import os
from icecream import ic
from evaluation.benchmarks_eval import (llm_text_localization_eval, llm_textcaps_textvqa_eval,llm_benchmark_eval)
from evaluation.due_benchmarks_eval import llm_duebenchmark_eval
import argparse

def read_jsonl(filename):
    """Load every record of a JSON Lines file into a list.

    Args:
        filename: path to the ``.jsonl`` file, decoded as UTF-8.

    Returns:
        A list holding the records yielded by ``jsonlines.Reader``, in file order.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        # Materialize the reader before the handle is closed on exit.
        return list(jsonlines.Reader(handle))


def save_jsonl(data, filename, print_log=True):
    """Write a list of records to ``filename`` in JSON Lines format.

    Each record is serialized with ``json.dumps(..., ensure_ascii=False)``, so
    non-ASCII text is stored verbatim rather than as ``\\uXXXX`` escapes. The
    file is therefore opened with an explicit UTF-8 encoding. Otherwise the
    platform locale encoding would be used, which can fail on non-ASCII text
    or disagree with ``read_jsonl`` (which decodes UTF-8).

    Records are separated by newlines; no trailing newline is written.

    Args:
        data: a list of JSON-serializable records.
        filename: destination path; an existing file is overwritten.
        print_log: if true, print how many samples were saved and where.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in data))

    if print_log:
        print('save %d samples to %s' % (len(data), filename))


# Datasets scored by llm_duebenchmark_eval (DUE-style benchmarks).
_DUE_DATASETS = ('DeepForm', 'DocVQA', 'InfographicsVQA', 'KleisterCharity', 'WikiTableQuestions')
# Datasets scored by llm_textcaps_textvqa_eval.
_TEXTCAPS_TEXTVQA_DATASETS = ('TextCaps', 'TextVQA')
# Datasets scored by llm_benchmark_eval, mapped to the metric names passed to it.
_BENCHMARK_METRICS = {
    'TabFact': ['ExactAccuracy'],
    'ChartQA': ['RelaxedAccuracy'],
    'VisualMRC': ['BLEU1', 'BLEU2', 'BLEU3', 'BLEU4', 'Meteor', 'RougeL', 'CIDEr'],
}


def _build_parser():
    """Build the command-line parser for a single-dataset evaluation run."""
    parser = argparse.ArgumentParser(description='docowl1.5 benchmark evaluation')
    # Only needed when predictions must be generated; checked later in main().
    parser.add_argument('--model_path', type=str, help='the directory path of model')
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['DocVQA', 'InfographicsVQA', 'WikiTableQuestions', 'DeepForm', 'KleisterCharity', 'TabFact',
                                 'ChartQA', 'TextVQA', 'TextCaps', 'VisualMRC'])
    parser.add_argument('--downstream_dir', type=str, required=True, help='the directory path of DocDownstream-1.0')
    parser.add_argument('--save_dir', type=str, required=True, help='the directory to save predictions of the model')
    return parser


def _run_inference(model_path, downstream_dir, test_path, save_path):
    """Run DocOwl on every sample of ``test_path`` and save the annotated samples.

    Each sample gets ``model_answer`` and ``gt_answer`` fields added. The full
    list is written to ``save_path`` via ``save_jsonl`` once all samples are done.

    Raises:
        FileNotFoundError: a sample references an image that does not exist.
        ValueError: a sample's first two messages are not user/assistant turns.
    """
    docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=True)
    print('load model from ', model_path)
    # infer the test samples one by one
    test_samples = read_jsonl(test_path)
    infer_results = []
    for sample in tqdm(test_samples):
        image = os.path.join(downstream_dir, sample['image'][0])
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(image):
            raise FileNotFoundError('image not found: %s' % image)
        question = sample['messages'][0]
        answer = sample['messages'][1]
        if question['role'] != 'user' or answer['role'] != 'assistant':
            raise ValueError('expected user/assistant messages, got %r/%r'
                             % (question['role'], answer['role']))
        # The image placeholder token is removed; the image path is passed separately.
        query = question['content'].replace('<|image|>', '')
        gt_answer = answer['content']
        model_answer = docowl.inference(image, query)
        sample['model_answer'] = model_answer
        sample['gt_answer'] = gt_answer
        ic(model_answer, gt_answer)
        infer_results.append(sample)
    save_jsonl(infer_results, save_path)


def _evaluate(dataset, pred_path, meta_dir):
    """Dispatch ``pred_path`` to the metric routine registered for ``dataset``."""
    if dataset in _DUE_DATASETS:
        llm_duebenchmark_eval(dataset_name=dataset, split='test', llm_pred_path=pred_path, meta_dir=meta_dir)
    elif dataset in _TEXTCAPS_TEXTVQA_DATASETS:
        llm_textcaps_textvqa_eval(result_path=pred_path, dataset=dataset, split='test', meta_dir=meta_dir)
    elif dataset in _BENCHMARK_METRICS:
        # Pass a copy so the module-level table cannot be mutated by the callee.
        llm_benchmark_eval(metric_names=list(_BENCHMARK_METRICS[dataset]), result_path=pred_path,
                           save_each_eval=True)


def main():
    """Generate (or reuse) predictions for one dataset, then compute its metrics.

    Predictions are cached at ``<save_dir>/<dataset>_test_pred.jsonl``. If that
    file already exists, inference is skipped and only evaluation runs.
    """
    parser = _build_parser()
    args = parser.parse_args()

    model_path = args.model_path
    dataset = args.dataset
    downstream_dir = args.downstream_dir
    save_dir = args.save_dir

    os.makedirs(save_dir, exist_ok=True)

    test_path = os.path.join(downstream_dir, 'test', dataset + '_test.jsonl')
    save_path = os.path.join(save_dir, dataset + '_test_pred.jsonl')

    if os.path.exists(save_path):
        print(save_path + ' exists, skip inference. ')
    else:
        if model_path is None:
            parser.error('--model_path is required when %s does not exist yet' % save_path)
        _run_inference(model_path, downstream_dir, test_path, save_path)

    # calculate metrics
    pred_path = save_path

    if not os.path.exists(pred_path):
        print('not exists:', pred_path)
        # Missing predictions are a failure: exit non-zero (was exit(0)).
        raise SystemExit(1)

    meta_dir = os.path.join(downstream_dir, 'meta')
    _evaluate(dataset, pred_path, meta_dir)

    print('==============================================')


if __name__ == '__main__':
    main()