"docker/Dockerfile.gb200" did not exist on "f200af0d8cde04ad746c37f51a40f7e218b6b581"
vllm_causallms.py 15.7 KB
Newer Older
baberabb's avatar
baberabb committed
1
from collections import defaultdict
baberabb's avatar
baberabb committed
2
from typing import List, Tuple, Optional, Literal, Union, Any
baberabb's avatar
baberabb committed
3
from transformers import AutoTokenizer
baberabb's avatar
baberabb committed
4
5
6
7
8
9
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
10

11
12
try:
    from vllm import LLM, SamplingParams
baberabb's avatar
baberabb committed
13
    from ray.util.multiprocessing import Pool
baberabb's avatar
baberabb committed
14
    from vllm.transformers_utils.tokenizer import get_tokenizer
15
16
except ModuleNotFoundError:
    pass
Hailey Schoelkopf's avatar
Hailey Schoelkopf committed
17

baberabb's avatar
baberabb committed
18
19
# module-level logger, taken from `lm_eval.utils` so this backend logs through the shared lm-eval logger
eval_logger = utils.eval_logger

baberabb's avatar
baberabb committed
20

baberabb's avatar
baberabb committed
21
# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
def run_inference_one_model(
    model_args: dict, sampling_params, requests: List[List[int]]
):
    """Build a vLLM engine from ``model_args`` and run one shard of prompts.

    Used as the ``Pool.starmap`` target for data-parallel inference in
    ``VLLM._model_generate``: every worker constructs its own ``LLM`` and
    generates for the shard it was given.

    :param model_args: keyword arguments forwarded unchanged to ``LLM``.
    :param sampling_params: ``SamplingParams`` applied to every prompt.
    :param requests: pre-tokenized prompts, one list of token ids per prompt.
    :return: whatever ``LLM.generate`` returns for this shard.
    """
    llm = LLM(**model_args)
    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)


baberabb's avatar
baberabb committed
29
30
31
32
33
34
35
36
37
38
@register_model("vllm")
class VLLM(LM):
    """lm-eval ``LM`` backend that runs inference through vLLM.

    Implements ``loglikelihood``, ``loglikelihood_rolling`` and ``generate_until``
    on top of ``LLM.generate``. With ``data_parallel_size <= 1`` a single engine
    is kept on ``self.model``. Otherwise every ``_model_generate`` call shards the
    prompts over a ray ``Pool``, and each worker builds its own engine from
    ``self.model_args`` (see ``run_inference_one_model``).
    """

    _DEFAULT_MAX_LENGTH = 2048
    # transformers tokenizers report int(1e30) as ``model_max_length`` when no
    # limit is configured; that value must not be used as a real context length.
    _HF_UNSET_MAX_LENGTH = int(1e30)

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[Literal["awq"]] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
    ):
        """Configure the vLLM engine(s) and the tokenizer.

        Model-related arguments (``pretrained``, ``dtype``, ``revision``,
        tokenizer options, ``tensor_parallel_size``, ``quantization``,
        ``swap_space``, ``seed``, ``gpu_memory_utilization``) are collected into
        ``self.model_args`` and passed to ``LLM``.

        :param batch_size: int, a numeric string, or any string containing
            ``"auto"``. It is normalized by ``_normalize_batch_size``.
        :param max_batch_size: accepted but currently unused by this backend.
        :param max_length: overrides the context length reported by ``max_length``.
        :param device: must be ``None`` or contain ``"cuda"``.
        :param data_parallel_size: values above 1 switch to ray-based data parallelism.
        :raises Exception: if the ``vllm`` package is not installed.
        """
        super().__init__()

        try:
            import vllm  # noqa: F401  (availability check only)
        except ModuleNotFoundError as err:
            raise Exception(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            ) from err

        # test for None first: `"cuda" in None` would raise TypeError
        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": self.tensor_parallel_size,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            # no engine here: `_model_generate` builds one per ray worker
            self.model_args["worker_use_ray"] = True

        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
        )

        self.batch_size = self._normalize_batch_size(batch_size)
        self._max_length = max_length
        self._max_gen_toks = max_gen_toks

    @staticmethod
    def _normalize_batch_size(batch_size):
        """Map a CLI-style batch size onto ``"auto"`` or an ``int``.

        Any string containing ``"auto"`` (e.g. ``"auto"``, ``"auto:4"``) becomes
        ``"auto"``. Other strings such as ``"8"`` are parsed with ``int`` so that
        ``utils.chunks`` receives a number. Non-string values pass through unchanged.
        """
        if isinstance(batch_size, str):
            return "auto" if "auto" in batch_size else int(batch_size)
        return batch_size

    @property
    def eot_token_id(self):
        """Token id used as end-of-text (the tokenizer's ``eos_token_id``)."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum number of tokens fed to the model.

        The order of precedence is:
        1. the explicit ``max_length`` constructor argument;
        2. the tokenizer's ``model_max_length``, unless it is missing, ``None``,
           or HF's "unset" sentinel;
        3. ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        model_max_length = getattr(self.tokenizer, "model_max_length", None)
        if (
            model_max_length is not None
            and model_max_length != self._HF_UNSET_MAX_LENGTH
        ):
            return model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        """Default number of new tokens for ``generate_until``."""
        return self._max_gen_toks

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=False,
        truncation=False,
    ):
        """Encode ``string`` with ``self.tokenizer.encode``.

        :param left_truncate_len: if truthy, keep only the last
            ``left_truncate_len`` tokens.
        :param add_special_tokens: forwarded to ``tokenizer.encode``.
        :param truncation: forwarded to ``tokenizer.encode``.
        :return: the (possibly left-truncated) token ids.
        """
        encoding = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens, truncation=truncation
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        use_tqdm=True,
        **kwargs,
    ):
        """Run vLLM on pre-tokenized prompts.

        :param requests: one list of token ids per prompt.
        :param generate: when True, sample continuations. ``max_tokens``,
            ``stop`` and the extra ``kwargs`` (after ``_modify_gen_kwargs``)
            become the ``SamplingParams``. When False, only score the prompts
            (``prompt_logprobs=2``, one greedy token generated); ``kwargs`` is
            ignored in that case.
        :param use_tqdm: forwarded to ``LLM.generate`` on the single-engine path.
        :return: the vLLM outputs, one per prompt. On the data-parallel path,
            the per-shard results are concatenated in shard order.
        """
        if generate:
            kwargs = self._modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=2, max_tokens=1
            )

        if self.data_parallel_size > 1:
            # one shard per worker; callers rely on the output order matching the
            # input order, so this assumes utils.divide yields contiguous shards
            requests = [
                list(x) for x in utils.divide(requests, self.data_parallel_size)
            ]
            inputs = [(self.model_args, sampling_params, req) for req in requests]

            with Pool(self.data_parallel_size) as pool:
                results = pool.starmap(run_inference_one_model, inputs)
            # flatten results
            return [item for sublist in results for item in sublist]

        return self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
        )

    @staticmethod
    def _modify_gen_kwargs(kwargs: dict) -> dict:
        """Translate HF-style generation kwargs into ``SamplingParams`` kwargs.

        ``do_sample`` is not a ``SamplingParams`` field, so it is removed.
        vLLM samples by default, so ``do_sample=False`` without an explicit
        ``temperature`` is mapped to ``temperature=0.0`` (greedy decoding).
        The special-token decoding options default to False, mirroring HF.
        The input dict is not mutated.
        """
        kwargs = dict(kwargs)
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0"
            )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs.setdefault("skip_special_tokens", False)
        kwargs.setdefault("spaces_between_special_tokens", False)
        return kwargs

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize ``context`` and ``continuation`` so they concatenate cleanly.

        Trailing whitespace of the context is moved to the front of the
        continuation. The continuation tokens are then taken as the suffix of the
        jointly encoded string beyond the context's own encoding length.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
        context_enc = self.tok_encode(context, add_special_tokens=False)

        continuation_enc = whole_enc[len(context_enc) :]
        return context_enc, continuation_enc

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score ``(context, continuation)`` pairs.

        Each ``req.args`` is unpacked as ``(context, continuation)``. An empty
        context is replaced by the EOT token.

        :return: one ``(sum of continuation logprobs, is_greedy)`` per request.
        """
        new_reqs = []
        for req in requests:
            context, continuation = req.args
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
                continuation_enc = self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Total log-likelihood of each full string, using rolling windows.

        The string is split into disjoint windows of at most
        ``max_length - 1`` tokens, each prefixed with EOT. The per-window
        logprobs are summed.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # `None` cache key: rolling windows are not cached individually
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # the outer tqdm already tracks progress; avoid one nested bar per string
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy and sum the window logprobs
            loglikelihoods.append(sum(x[0] for x in string_nll))
        return loglikelihoods

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a continuation for every ``(context, gen_kwargs)`` request.

        Requests are grouped by their ``gen_kwargs``, so each batch shares a
        single ``SamplingParams``. Within a group, requests are sorted by
        descending context length. Results are returned in the original
        request order.

        :raises ValueError: if ``gen_kwargs`` is not a dict, or if its ``until``
            entry is neither a str nor a list.
        """
        if not requests:
            return []

        res = defaultdict(list)
        re_ords = {}

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding = self.tokenizer(context, add_special_tokens=False).input_ids
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), tuple(_requests[0][1])

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            # Only this group's `reqs` may be used here; passing all `requests`
            # would re-run every request once per group with the wrong kwargs.
            re_ords[key] = utils.Reorderer(reqs, _collate_gen)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            chunks = utils.chunks(
                re_ord.get_reordered(),
                n=self.batch_size if self.batch_size != "auto" else 0,
                fn=None,
            )
            for chunk in chunks:
                context_and_encoding, all_gen_kwargs = zip(*chunk)
                contexts, context_encoding = zip(*context_and_encoding)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = all_gen_kwargs[0]
                if not isinstance(gen_kwargs, dict):
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )
                # unpack our keyword arguments.
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                until = None
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
                if not until:
                    until = [self.tokenizer.decode(self.eot_token_id)]
                max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

                # set the max length in tokens of inputs ("context_enc")
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                context_encoding = [x[-max_ctx_len:] for x in context_encoding]

                # TODO: max_length in kwargs

                # perform batched generation
                cont = self._model_generate(
                    requests=context_encoding,
                    generate=True,
                    max_tokens=max_gen_toks,
                    stop=until,
                    **kwargs,
                )

                # cache generations
                for output, ctx in zip(cont, contexts):
                    generated_text = output.outputs[0].text
                    res[key].append(generated_text)
                    self.cache_hook.add_partial(
                        "generate_until", (ctx, gen_kwargs), generated_text
                    )
                    pbar.update(1)

            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized ``(cache_key, context_enc, continuation_enc)`` triples.

        Each ``context_enc + continuation_enc`` is left-truncated to
        ``max_length`` tokens and scored in one prompt-logprob pass. A
        ``cache_key`` of ``None`` skips partial caching.

        :return: ``(sum of continuation logprobs, is_greedy)`` per request,
            in the original order.
        """
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        chunks = utils.chunks(
            re_ord.get_reordered(),
            n=self.batch_size if self.batch_size != "auto" else 0,
            fn=None,
        )
        max_length = self.max_length  # loop-invariant property lookup
        pbar = tqdm(total=len(requests), disable=disable_tqdm)
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-max_length:]
                # number of context tokens that survive the left truncation
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - max_length
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                # pass the truncated tokens actually fed to the model, so that
                # `ctxlen` and the prompt logprobs index the same sequence
                answer = self._parse_logprobs(inp, output, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                # advance for every request, including uncached rolling windows
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Token ids exactly as fed to the model (context + continuation, after
            any truncation); must align position-by-position with
            ``outputs.prompt_logprobs``.
        :param outputs: RequestOutput
            Contains prompt
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        # prompt_logprobs = [None, {}*len(context-1)]
        continuation_logprobs_dicts = outputs.prompt_logprobs

        # assume ctxlen always >= 1: position 0 has no logprob dict
        pairs = list(zip(tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]))

        # Calculate continuation_logprobs
        continuation_logprobs = sum(
            logprob_dict.get(token) for token, logprob_dict in pairs
        )

        # greedy iff every non-empty dict's top-logprob token equals the given token
        is_greedy = all(
            not logprob_dict or max(logprob_dict, key=logprob_dict.get) == token
            for token, logprob_dict in pairs
        )

        return continuation_logprobs, is_greedy