"vscode:/vscode.git/clone" did not exist on "99ec439da476c1a83ce29863395433833f0ac850"
eval.py 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
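"""Evaluate a LayoutLMv3 token-classification model on a gzipped JSONL dataset.

For each example, the predicted index sequence is scored against ``target_index``
and the correspondingly reordered source tokens against ``target_texts``, using
smoothed sentence BLEU averaged over the dataset.

Example invocation (paths are illustrative, not from this repo):
    python eval.py data/test.jsonl.gz path/to/checkpoint --batch-size 16
"""
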
import gzip
import json

import torch
import typer
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification

from helpers import (
    DataCollator,
    check_duplicate,
    MAX_LEN,
    parse_logits,
    prepare_inputs,
)

app = typer.Typer()

chen_cherry = SmoothingFunction()


@app.command()
def main(
    input_file: str = typer.Argument(..., help="input file"),
    model_path: str = typer.Argument(..., help="model path"),
    batch_size: int = typer.Option(16, help="batch size"),
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = (
        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
        .bfloat16()
        .to(device)
        .eval()
    )
    data_collator = DataCollator()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

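    # Each JSONL record is expected to provide at least the fields used below:
    # "source_boxes", "source_texts", "target_texts", and "target_index".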
    datasets = []
    with gzip.open(input_file, "rt") as f:
        for line in tqdm(f):
            datasets.append(json.loads(line))
    # Sort by sequence length so each batch contains similarly sized examples,
    # which reduces padding and speeds up batched inference.
    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)

    total = 0
    total_out_idx = 0.0
    total_out_token = 0.0
    for i in tqdm(range(0, len(datasets), batch_size)):
        batch = datasets[i : i + batch_size]
        model_inputs = data_collator(batch)
        model_inputs = prepare_inputs(model_inputs, model)
        # forward
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        logits = model_outputs.logits.cpu()
        for data, logit in zip(batch, logits):
            target_index = data["target_index"][:MAX_LEN]
            pred_index = parse_logits(logit, len(target_index))
            assert len(pred_index) == len(target_index)
            assert not check_duplicate(pred_index)
            target_texts = data["target_texts"][:MAX_LEN]
            source_texts = data["source_texts"][:MAX_LEN]
            pred_texts = [source_texts[idx] for idx in pred_index]
            total += 1
            # pred_index is 0-based while target_index appears to be 1-based,
            # hence the +1 shift before scoring.
            total_out_idx += sentence_bleu(
                [target_index],
                [idx + 1 for idx in pred_index],
                smoothing_function=chen_cherry.method2,
            )
            total_out_token += sentence_bleu(
                [" ".join(target_texts).split()],
                " ".join(pred_texts).split(),
                smoothing_function=chen_cherry.method2,
            )

    print("total: ", total)
    print("out_idx: ", round(100 * total_out_idx / total, 1))
    print("out_token: ", round(100 * total_out_token / total, 1))


if __name__ == "__main__":
    app()