import copy
import logging
from typing import Dict, List, Optional

import transformers
from more_itertools import distribute
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    handle_stop_sequences,
    replace_placeholders,
    resize_image,
    undistribute,
)
from lm_eval.models.vllm_causallms import VLLM


# Module-level logger (used, e.g., to report the image limit passed to vLLM).
eval_logger = logging.getLogger(__name__)


# vLLM and Ray are optional dependencies. A failed import is deliberately
# ignored so that this module (and its @register_model registration) can be
# imported without them installed.
# NOTE(review): nothing in this file re-checks availability; presumably the
# VLLM base class reports the missing package when it is constructed. Confirm
# this, otherwise the first use of `ray`, `LLM` or `SamplingParams` below fails
# with a NameError.
try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest  # noqa: F401
    from vllm.transformers_utils.tokenizer import get_tokenizer  # noqa: F401
except ModuleNotFoundError:
    pass


# Marker for an image position inside a raw text prompt. apply_chat_template
# counts and splits on it, and tok_batch_multimodal_encode hands it to
# replace_placeholders.
DEFAULT_IMAGE_PLACEHOLDER = "<image>"


@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
    MULTIMODAL = True

    def __init__(
        self,
        pretrained: str,
        trust_remote_code: Optional[bool] = False,
        revision: Optional[str] = None,
        interleave: bool = True,
        # TODO<baber>: handle max_images and limit_mm_per_prompt better
        max_images: int = 999,
        image_width: Optional[int] = None,
        image_height: Optional[int] = None,
        image_max_side: Optional[int] = None,
        **kwargs,
    ):
        """Create a vLLM-backed vision-language model wrapper.

        Args:
            pretrained: Model name or path. It is forwarded to ``VLLM.__init__``
                and to ``transformers.AutoProcessor.from_pretrained``.
            trust_remote_code: Forwarded to both the base class and the processor.
            revision: Model revision, forwarded to both the base class and the
                processor.
            interleave: If True, ``apply_chat_template`` puts each image slot
                where its ``<image>`` placeholder appeared. If False, all image
                slots of a message come before its text.
            max_images: Per-prompt image cap. It truncates the image lists and
                placeholder handling in this class. Any value other than 999 is
                also passed to vLLM as ``limit_mm_per_prompt={"image": ...}``.
            image_width: Target width handed to ``resize_image``.
            image_height: Target height handed to ``resize_image``.
            image_max_side: Alternative resize setting. It cannot be combined
                with ``image_width`` or ``image_height``.
            **kwargs: Extra arguments forwarded to ``VLLM.__init__``.

        Raises:
            ValueError: If ``image_max_side`` is given together with
                ``image_width`` or ``image_height``.
        """
        self.image_width, self.image_height, self.image_max_side = (
            image_width,
            image_height,
            image_max_side,
        )

        # Reject an ambiguous resize config before the (presumably expensive)
        # base-class initialisation runs.
        sized_by_dims = image_width or image_height
        if image_max_side and sized_by_dims:
            raise ValueError(
                "Ambiguous config for image resize: you can not specify both "
                "image_max_side and (image_width or image_height)"
            )

        # 999 acts as "no explicit limit": only a user-chosen cap is pushed
        # down to the vLLM engine.
        has_explicit_image_cap = max_images != 999
        if has_explicit_image_cap:
            kwargs["limit_mm_per_prompt"] = {"image": max_images}
            eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")

        super().__init__(
            pretrained=pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
        self.interleave = interleave
        self.max_images = max_images
        # The HF processor supplies the chat template used by apply_chat_template.
        self.processor = transformers.AutoProcessor.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        # Flipped to True once apply_chat_template has run; it controls whether
        # tok_batch_multimodal_encode rewrites placeholders itself.
        self.chat_applied: bool = False
        self.chat_applied: bool = False

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],  # note that input signature of this fn is different
        images,  # TODO: typehint on this
        left_truncate_len: int = None,
        truncation: bool = False,
    ):
        """Pair every prompt with its images in vLLM's multimodal input format.

        Args:
            strings: One prompt string per request.
            images: One sequence of images per request. Only the first
                ``self.max_images`` entries of each sequence are kept.
            left_truncate_len: Accepted for interface compatibility; not used.
            truncation: Accepted for interface compatibility; not used.

        Returns:
            A list of ``{"prompt": ..., "multi_modal_data": {"image": ...}}``
            dicts, one per ``(string, images)`` pair. Pairing uses ``zip``, so
            the output stops at the shorter of the two inputs.
        """
        cap = self.max_images
        clipped = [per_request[:cap] for per_request in images]

        # When no chat template has been applied, the raw prompts still carry
        # their <image> markers. They are rewritten through replace_placeholders
        # using the same cap as the images (see that helper for exact semantics).
        # TODO<baber>: is the default placeholder always <image>?
        if self.chat_applied is False:
            strings = [
                replace_placeholders(
                    prompt,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    cap,
                )
                for prompt in strings
            ]

        return [
            {"prompt": prompt, "multi_modal_data": {"image": request_images}}
            for prompt, request_images in zip(strings, clipped)
        ]

    def _model_generate(
        self,
        requests: List[List[dict]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run the vLLM engine over a batch of multimodal prompts.

        Args:
            requests: The prompt dicts to run. In this class they are the
                output of ``tok_batch_multimodal_encode``, which is how
                ``generate_until`` calls this method.
            generate: If True, sample a continuation using ``max_tokens``,
                ``stop`` and the remaining ``kwargs`` (after
                ``self.modify_gen_kwargs``). If False, score the prompt only:
                temperature 0, one prompt logprob, one generated token, no
                detokenisation.
            max_tokens: Generation budget when ``generate`` is True.
            stop: Stop sequences when ``generate`` is True.
            **kwargs: Extra ``SamplingParams`` fields for generation.

        Returns:
            The outputs of ``LLM.generate``. With ``data_parallel_size > 1``,
            the per-worker results are merged by ``undistribute``, presumably
            restoring the original request order.

        Side effects:
            In the data-parallel path, ``ray.shutdown()`` is called after the
            workers finish.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )

        if self.data_parallel_size > 1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params,
                requests: List[List[dict]],
                lora_request=None,
            ):
                llm = LLM(**model_args)
                # Forward the adapter so data-parallel runs behave like the
                # single-process path below. It was previously dropped here.
                return llm.generate(
                    requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        # lora_request defaults to None in LLM.generate, so one call covers
        # both the adapter and the no-adapter case.
        return self.model.generate(
            requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
            lora_request=self.lora_request,
        )

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """Render a chat history with the HF processor's chat template.

        Each message whose ``content`` is a string is converted into the
        list-of-parts form (``{"type": "text", ...}`` / ``{"type": "image", ...}``),
        with at most ``self.max_images`` image slots per message:

        * ``self.interleave`` True: an image slot sits where each ``<image>``
          placeholder was, between the surrounding text pieces.
        * ``self.interleave`` False: all image slots come first, followed by one
          text part with the placeholders removed.

        The caller's ``chat_history`` is left untouched; converted shallow copies
        of the messages are rendered instead. Messages whose content is not a
        string (for example, already converted) are passed through unchanged.
        Previously the history was rewritten in place, so a second call on the
        same history crashed.

        Also sets ``self.chat_applied`` so that ``tok_batch_multimodal_encode``
        skips its own placeholder rewriting.

        Args:
            chat_history: Messages with at least a ``"content"`` key.
            add_generation_prompt: Passed to the processor. Its negation is
                passed as ``continue_final_message``.

        Returns:
            Whatever ``self.processor.apply_chat_template`` returns (annotated as str).
        """
        self.chat_applied = True
        to_parts = (
            self._interleaved_content_parts
            if self.interleave
            else self._images_first_content_parts
        )

        rendered_history = []
        for message in chat_history:
            content = message["content"]
            if isinstance(content, str):
                message = {**message, "content": to_parts(content)}
            rendered_history.append(message)

        return self.processor.apply_chat_template(
            rendered_history,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    def _images_first_content_parts(self, text: str) -> List[dict]:
        """Build image slots first, then one text part with placeholders removed."""
        image_count = min(self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER))
        parts = [{"type": "image", "image": None} for _ in range(image_count)]
        # Single text entry at the end. It is kept even when empty, as before.
        parts.append(
            {"type": "text", "text": text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")}
        )
        return parts

    def _interleaved_content_parts(self, text: str) -> List[dict]:
        """Split text on placeholders and keep an image slot at each of the first ``max_images``."""
        segments = text.split(DEFAULT_IMAGE_PLACEHOLDER)
        last = len(segments) - 1
        parts = []
        for i, segment in enumerate(segments):
            # TODO: concatenate text parts (esp. if skipping images)?
            if segment:  # skip empty text between adjacent placeholders
                parts.append({"type": "text", "text": segment})
            # Every segment except the last was followed by a placeholder. Only
            # the first max_images of them become image slots, and the rest are
            # dropped. This yields exactly min(max_images, placeholder count)
            # slots, so no separate count check is needed.
            if i < last and i < self.max_images:
                parts.append({"type": "image"})
        return parts

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for each text+image request.

        Requests are ordered longest-context-first and batched by identical
        ``gen_kwargs`` via ``Collator``. For each batch, the images are resized,
        paired with the prompts by ``tok_batch_multimodal_encode``, and sent to
        vLLM in one ``_model_generate`` call. Results come back in input order.

        Args:
            requests: Instances whose ``args`` unpack as
                ``(context, gen_kwargs, aux)``, where ``aux["visual"]`` holds the
                request's images.
            disable_tqdm: Hide the progress bar. It is always hidden on
                non-zero ranks.

        Returns:
            The generated text for each request, in the original order.

        Raises:
            ValueError: If a batch's ``gen_kwargs`` is not a dict.
        """
        # TODO: support text-only reqs
        results = []

        def _collate(x):
            # Sort key: negated tokenized length (longest first), then the text.
            # Going longest-first makes time estimates err on the high side,
            # gives each batch its padded length from its first item, and
            # surfaces any OOM immediately rather than near the end.
            tokens = self.tok_encode(x[0])
            return -len(tokens), x[0]

        progress = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests with text+image input",
        )
        # TODO: port auto-batch sizing into this.

        # Grouping by gen_kwargs keeps, e.g., greedy decoding and temperature
        # sampling out of the same batch.
        reordered = Collator(
            [request.args for request in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        batches = reordered.get_batched(n=self.batch_size, batch_fn=None)
        eos = self.tokenizer.decode(self.eot_token_id)

        for batch in batches:
            contexts, all_gen_kwargs, aux_arguments = zip(*batch)

            width, height, max_side = (
                self.image_width,
                self.image_height,
                self.image_max_side,
            )
            visuals = []
            for aux in aux_arguments:
                visuals.append(
                    [resize_image(img, width, height, max_side) for img in aux["visual"]]
                )

            # zip() produced a tuple. Qwen2-VL's processor is unhappy with a
            # tuple of strings, so pass a list instead.
            # TODO: could we upstream this workaround to HF?
            contexts = list(contexts)

            # The Collator grouped by gen_kwargs, so the first entry speaks for
            # the whole batch.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            # Stop sequences always include the EOS text.
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            max_gen_toks = (
                kwargs.pop("max_gen_toks")
                if "max_gen_toks" in kwargs
                else self.max_gen_toks
            )

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=self.max_length - max_gen_toks,
            )
            outputs = self._model_generate(
                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
            )

            for output, context in zip(outputs, contexts):
                text = output.outputs[0].text
                results.append(text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), text
                )
                progress.update(1)

        # Undo the length-based sort so results line up with `requests`.
        results = reordered.get_original(results)

        progress.close()
        return results