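"""Benchmark a Hugging Face AutoModelForImageTextToText model on MMMU.

Prepares evaluation samples via eval_utils, greedily generates an answer for
each image/question pair, and scores the decoded responses with eval_result.
"""
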
import argparse

import torch
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig


@torch.no_grad()
def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)

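    # Load the model in its checkpoint dtype and run it in eval mode on the GPU.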
    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype="auto",
        trust_remote_code=True,
    )
    model = model.eval().cuda()

    processor = AutoProcessor.from_pretrained(args.model_path)

    samples = prepare_samples(eval_args)
    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)
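
    # Greedy decoding; only the max_new_tokens budget is taken from the
    # benchmark's sampling parameters.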
    generation_config = GenerationConfig(
        max_new_tokens=sampling_params["max_new_tokens"],
        do_sample=False,
    )

    answer_dict = {}
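    # Decode each sample in turn and accumulate its answer for scoring.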
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]
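        # MMMU prompts embed a single "<image N>" placeholder; take the text
        # before "<" as the prefix and the text after the first ">" as the suffix.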
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        assert image is not None
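        # Interleave prefix text, the image, and suffix text as a single user turn.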
        contents = []
        if prefix:
            contents += [{"type": "text", "text": prefix}]
        contents += [
            {
                "type": "image",
                "image": sample["image_path"],
            }
        ]
        if suffix:
            contents += [{"type": "text", "text": suffix}]
        messages = [{"role": "user", "content": contents}]
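        # Tokenize through the model's chat template and move inputs to its device.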
        model_inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, generation_config=generation_config)
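        # Keep only the newly generated tokens; the prompt tokens come first.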
        generation = generation[0][input_len:]
        response = processor.decode(generation, skip_special_tokens=True)
        print(f"response: {response}")
        process_result(response, sample, answer_dict, out_samples)

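    # Write results to "<model_path>_val_hf.json" and score them with eval_result.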
    args.output_path = f"{args.model_path}_val_hf.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        required=True,
    )
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)
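
# Example invocation (the checkpoint name is illustrative; any model loadable
# through AutoModelForImageTextToText should work):
#
#   python bench_hf.py --model-path Qwen/Qwen2.5-VL-7B-Instruct
#
# Additional evaluation flags are registered by EvalArgs.add_cli_args
# (see eval_utils).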