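"""Benchmark a Hugging Face vision-language model on the MMMU validation split.

Example (any local folder or Hugging Face repo ID works as --model-path):
    python bench_hf.py --model-path <model-path> [EvalArgs options]
"""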
import argparse

import torch
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, GenerationConfig


@torch.no_grad()
def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)
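    # Prefer the generic image-text-to-text auto class; fall back to plain
    # AutoModel for architectures not registered with it. init_tts=False is a
    # model-specific kwarg (e.g. MiniCPM-o) that skips loading the TTS head.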
    try:
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            args.model_path,
            torch_dtype="auto",
            trust_remote_code=True,
        )
    except Exception as first_exception:
        try:
            model = AutoModel.from_pretrained(
                args.model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                init_tts=False,
            )
        except Exception as second_exception:
            raise RuntimeError(
                f"Failed to load model: First attempt failed with {first_exception}, "
                f"second attempt failed with {second_exception}"
            ) from second_exception

    # Inference only: disable dropout and move the model to the current GPU.
    model = model.eval().cuda()

    processor = AutoProcessor.from_pretrained(
        args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )

    samples = prepare_samples(eval_args)
    out_samples = {}

    sampling_params = get_sampling_params(eval_args)
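    # Greedy decoding (do_sample=False) keeps results deterministic; only the
    # max_new_tokens budget is taken from the eval's sampling params.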
    generation_config = GenerationConfig(
        max_new_tokens=sampling_params["max_new_tokens"],
        do_sample=False,
    )

    answer_dict = {}
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]
        # The prompt embeds one "<image 1>"-style placeholder; keep the text
        # before it as the prefix and the text after it as the suffix.
        prefix = prompt.split("<", 1)[0]
        suffix = prompt.split(">", 1)[1]
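        # Build a single-turn chat message: optional prefix text, the image,
        # then optional suffix text.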
        assert image is not None
        contents = []
        if prefix:
            contents += [{"type": "text", "text": prefix}]
        contents += [
            {
                "type": "image",
                "image": sample["image_path"],
            }
        ]
        if suffix:
            contents += [{"type": "text", "text": suffix}]
        messages = [{"role": "user", "content": contents}]
        model_inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, generation_config=generation_config)
        # Drop the prompt tokens so only the newly generated answer is decoded.
        generation = generation[0][input_len:]
        response = processor.decode(generation, skip_special_tokens=True)
        print(f"response: {response}")
        process_result(response, sample, answer_dict, out_samples)

    # Save raw answers to "<model_path>_val_hf.json", then score them.
    args.output_path = f"{args.model_path}_val_hf.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        required=True,
    )
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)