"Speech2S/speech2s/scripts/shard_docs.py" did not exist on "417b607b2a622da9321c932a5b3bc0f6b0ece56b"
bench_hf.py 5.27 KB
Newer Older
1
2
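"""Benchmark multimodal models on MMMU-style samples via Hugging Face transformers.

Loads a model through a chain of fallbacks (generic image-text-to-text,
InternVL-specific, then a chat-style AutoModel), runs greedy decoding over the
samples prepared by eval_utils, and writes the answers and evaluation results
to JSON files derived from the model path.

Usage:
    python bench_hf.py --model-path <local folder or Hugging Face repo ID>
    (plus whatever flags EvalArgs.add_cli_args registers)
"""
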
import argparse

import PIL
import torch
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, GenerationConfig


@torch.no_grad()
def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)

    sampling_params = get_sampling_params(eval_args)
    # Greedy decoding: only max_new_tokens is taken from the sampling params.
    generation_config = GenerationConfig(
        max_new_tokens=sampling_params["max_new_tokens"],
        do_sample=False,
    )

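    # Model loading tries three strategies in order:
    #   1. the generic AutoModelForImageTextToText entry point,
    #   2. InternVL checkpoints, which need their own tokenizer and chat() API,
    #   3. AutoModel with init_tts=False (for chat-style models that would
    #      otherwise initialize a text-to-speech head).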
    try:
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            args.model_path,
            torch_dtype="auto",
            trust_remote_code=True,
        )
    except Exception as first_exception:
        try:
            # Check whether this is an InternVL checkpoint, which needs its
            # own tokenizer and chat()-style generation.
            if "InternVL" in args.model_path:
                from internvl_utils import load_image
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(args.model_path)
                model = AutoModel.from_pretrained(
                    args.model_path,
                    torch_dtype="auto",
                    trust_remote_code=True,
                )
                generation_config_internvl = dict(
                    max_new_tokens=sampling_params["max_new_tokens"], do_sample=False
                )

            else:
                model = AutoModel.from_pretrained(
                    args.model_path,
                    torch_dtype="auto",
                    trust_remote_code=True,
                    init_tts=False,
                )
        except Exception as second_exception:
            raise RuntimeError(
                f"Failed to load model: First attempt failed with {first_exception}, "
                f"second attempt failed with {second_exception}"
            ) from second_exception

    model = model.eval().cuda()

    processor = AutoProcessor.from_pretrained(
        args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )

    samples = prepare_samples(eval_args)
    out_samples = dict()

    answer_dict = {}
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]
        # Split the prompt into the text before and after the "<...>" image
        # placeholder so the two halves can be passed around the image.
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        assert image is not None

        if "InternVL" in args.model_path:
            pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda()
            contents = ""
            if prefix:
                contents += prefix
            contents += "<image>\n"
            if suffix:
                contents += suffix
            response = model.chat(
                tokenizer, pixel_values, contents, generation_config_internvl
            )
            print(f"response: {response}")
            process_result(response, sample, answer_dict, out_samples)
            continue

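        # Generic path: build chat-template messages with the image placed
        # between the prefix and suffix text.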
        contents = []
        if prefix:
            contents += [{"type": "text", "text": prefix}]
        contents += [
            {
                "type": "image",
                "image": sample["image_path"],
            }
        ]
        if suffix:
            contents += [{"type": "text", "text": suffix}]
        messages = [{"role": "user", "content": contents}]
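        # Prefer the tokenizer's chat template with model.generate(); if the
        # model does not support it, fall back to its own chat() interface
        # (e.g. MiniCPM-o-style models that take raw PIL images).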
        try:
            model_inputs = processor.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                return_dict=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)
            input_len = model_inputs["input_ids"].shape[-1]
            generation = model.generate(
                **model_inputs, generation_config=generation_config
            )
            generation = generation[0][input_len:]
            response = processor.decode(generation, skip_special_tokens=True)
        except Exception:
            contents = []
            if prefix:
                contents += [prefix]
            image = PIL.Image.open(sample["image_path"])
            contents += [image]
            if suffix:
                contents += [suffix]
            messages = [{"role": "user", "content": contents}]
            response = model.chat(
                msgs=messages,
                tokenizer=processor.tokenizer,
                sampling=False,
                max_new_tokens=sampling_params["max_new_tokens"],
                use_tts_template=False,
                generate_audio=False,
                temperature=0.0,
            )
        print(f"response: {response}")
        process_result(response, sample, answer_dict, out_samples)

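    # Persist the raw answers next to the model path, then score them.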
    args.output_path = f"{args.model_path}_answer_hf.json"
    save_json(args.output_path, out_samples)
    eval_result(
        model_answer_path=args.output_path,
        answer_dict=answer_dict,
        eval_output_path=f"{args.model_path}_val_hf.json",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        required=True,
    )
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)