"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import argparse
import json
import os

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm

from donut import DonutModel, JSONParseEvaluator, save_json


def test(args):
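    """Evaluate a pretrained Donut model on a dataset; returns the raw predictions."""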
    pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)

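    # Prefer fp16 on GPU for speed and memory; on CPU, cast the encoder to
    # bfloat16 instead, since fp16 kernels are poorly supported on CPU.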
    if torch.cuda.is_available():
        pretrained_model.half()
        pretrained_model.to("cuda")
    else:
        pretrained_model.encoder.to(torch.bfloat16)

    pretrained_model.eval()

    if args.save_path:
        # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
        # so fall back to the current directory.
        os.makedirs(os.path.dirname(args.save_path) or ".", exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []

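    # JSONParseEvaluator scores structured outputs: cal_acc gives a per-sample
    # tree-edit-distance (TED) accuracy, cal_f1 a field-level F1 over the whole set.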
    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for sample in tqdm(dataset):
        ground_truth = json.loads(sample["ground_truth"])

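        # docvqa conditions generation on the question; every other task is
        # prompted with just its start token and must emit the full parse.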
        if args.task_name == "docvqa":
            output = pretrained_model.inference(
                image=sample["image"],
                prompt=f"<s_{args.task_name}><s_question>{ground_truth['gt_parses'][0]['question'].lower()}</s_question><s_answer>",
            )["predictions"][0]
        else:
            output = pretrained_model.inference(image=sample["image"], prompt=f"<s_{args.task_name}>")["predictions"][0]

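        # Per-task scoring: exact class match (rvlcdip), exact answer match
        # (docvqa), or TED-based accuracy for information extraction tasks.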
        if args.task_name == "rvlcdip":
            gt = ground_truth["gt_parse"]
            score = float(output["class"] == gt["class"])
        elif args.task_name == "docvqa":
            # Note: official DocVQA scores were computed on the challenge's
            # evaluation server; this script reports an exact-match score instead.
            gt = ground_truth["gt_parses"]
            answers = {qa_parse["answer"] for qa_parse in gt}
            score = float(output["answer"] in answers)
        else:
            gt = ground_truth["gt_parse"]
            score = evaluator.cal_acc(output, gt)

        accs.append(score)

        predictions.append(output)
        ground_truths.append(gt)

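    # Report per-sample TED accuracies, their mean, and the corpus-level F1.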
    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dataset_name_or_path", type=str)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    args, left_argv = parser.parse_known_args()

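    # Without an explicit task name, infer it from the dataset path,
    # e.g. "naver-clova-ix/cord-v2" -> "cord-v2".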
    if args.task_name is None:
        args.task_name = os.path.basename(args.dataset_name_or_path)

    predictions = test(args)
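
# Example invocation (model and dataset names are illustrative):
#   python test.py --pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2 \
#       --dataset_name_or_path naver-clova-ix/cord-v2 --split test \
#       --save_path ./result/cord-v2/output.json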