openai_completions.py 18.5 KB
Newer Older
1
import copy
lintangsutawika's avatar
update  
lintangsutawika committed
2
from collections import defaultdict
3
from importlib.util import find_spec
4
from typing import List, Literal, Optional, Tuple
5

6
import transformers
Leo Gao's avatar
Leo Gao committed
7
from tqdm import tqdm
lintangsutawika's avatar
update  
lintangsutawika committed
8

lintangsutawika's avatar
lintangsutawika committed
9
from lm_eval import utils
10
11
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
12
from lm_eval.utils import retry_on_specific_exceptions
Leo Gao's avatar
Leo Gao committed
13

lintangsutawika's avatar
update  
lintangsutawika committed
14

Baber Abbasi's avatar
Baber Abbasi committed
15
def get_result(response, ctxlen: int) -> Tuple[float, bool]:
    """Process one choice of an echoed OpenAI completions response.

    :param response:
        One element of ``response.choices`` from a completions call made with
        ``echo=True`` and ``logprobs`` set; its ``logprobs`` attribute exposes
        the parallel sequences ``tokens``, ``token_logprobs`` and ``top_logprobs``.
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            Whether the most likely token matches the given continuation token
            at every continuation position
    """
    logprobs = response.logprobs
    continuation_logprobs = sum(logprobs.token_logprobs[ctxlen:])

    is_greedy = True
    # Compare token *strings*: the keys of each top_logprobs dict are token
    # strings, so the actual token must come from `tokens`, not `token_logprobs`.
    for token, top_tokens in zip(
        logprobs.tokens[ctxlen:], logprobs.top_logprobs[ctxlen:]
    ):
        top_token = max(top_tokens, key=top_tokens.get)
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(client=None, **kwargs):
    """Query an OpenAI-style completions endpoint.

    Retries forever on ``openai.OpenAIError``, printing the traceback of each
    failed attempt.

    :param client:
        Client whose ``completions.create`` is called. When None, the
        ``openai`` module-level default client is used. That client does not
        know about a custom ``base_url`` configured on an LM instance.
    :param kwargs:
        Forwarded unchanged to ``client.completions.create``.
    :return:
        Whatever ``client.completions.create`` returns.
    :raises Exception:
        If the ``openai`` or ``tiktoken`` package is not installed.
    """
    if not find_spec("openai") or not find_spec("tiktoken"):
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
        )
    import openai

    if client is None:
        # Backward-compatible default: the module-level client.
        client = openai

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        import traceback

        traceback.print_exc()

    @retry_on_specific_exceptions(
        on_exceptions=[openai.OpenAIError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        return client.completions.create(**kwargs)

    return completion()


@register_model("openai-completions", "local-completions")
class OpenaiCompletionsLM(LM):
    """LM backed by an OpenAI-style (non-chat) completions endpoint.

    Requests go to the OpenAI API, or to an OpenAI-compatible server when
    ``base_url`` is given. Log-likelihoods are computed from echoed prompt
    logprobs.
    """

    REQ_CHUNK_SIZE = 20
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-instruct",
        tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
        batch_size=1,
        base_url: str = None,
        truncate: bool = False,
        max_gen_toks: int = 256,
        seed: int = 1234,
        max_length: Optional[int] = None,
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
    ) -> None:
        """
        :param model: str
            Model name sent to the completions endpoint. With the
            ``huggingface`` backend it is also the name/path the tokenizer is
            loaded from.
        :param tokenizer_backend: str
            ``"tiktoken"`` or ``"huggingface"``; any other value raises ValueError.
        :param batch_size:
            Accepted but not used; requests are sent in chunks of
            ``REQ_CHUNK_SIZE``.
        :param base_url: str
            Address of an OpenAI-compatible server. When unset, the default
            OpenAI client endpoint is used.
        :param truncate: bool
            Stored as ``self.truncate``; not read by any method of this class.
        :param max_gen_toks: int
            Default ``max_tokens`` for generation when a request does not set
            ``max_gen_toks``.
        :param seed: int
            Passed as ``seed`` on every API request.
        :param max_length: int, optional
            Context window; ``_DEFAULT_MAX_LENGTH`` is used when falsy.
        :param revision, trust_remote_code, use_fast_tokenizer:
            Forwarded to ``AutoTokenizer.from_pretrained`` (huggingface backend only).
        """
        super().__init__()
        self.seed = seed
        try:
            import openai  # noqa: E401
            import tiktoken
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        self.base_url = base_url
        self.tokenizer_backend = tokenizer_backend
        self.truncate = truncate
        self._max_gen_toks = max_gen_toks
        self._max_length = max_length

        # if we have a local model, use HF tokenizer over tiktoken
        if self.tokenizer_backend == "huggingface":
            self.revision = revision
            self.trust_remote_code = trust_remote_code
            self.use_fast_tokenizer = use_fast_tokenizer

            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                self.model,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                # `use_fast` is the keyword AutoTokenizer uses to pick fast/slow
                use_fast=self.use_fast_tokenizer,
            )
            self.vocab_size = self.tokenizer.vocab_size
            # Must be an integer id: loglikelihood() uses it as a token list
            # ([self.eot_token_id]) for empty contexts.
            self.end_of_text_token_id = self.tokenizer.eos_token_id
        elif self.tokenizer_backend == "tiktoken":
            self.tokenizer = tiktoken.encoding_for_model(self.model)
            self.vocab_size = self.tokenizer.n_vocab
            self.end_of_text_token_id = self.tokenizer.eot_token
        else:
            raise ValueError(
                f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
            )

        # Read from environment variable OPENAI_API_KEY
        # Set to EMPTY for local
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()  # openai.AsyncOpenAI()

    @property
    def eot_token_id(self):
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        if self._max_length:
            return self._max_length
        else:
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        return self._max_gen_toks

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize a (context, continuation) pair.

        Trailing whitespace of the context is moved to the front of the
        continuation. The continuation tokens are the tail of the jointly
        encoded string beyond the context's token count.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = (
                    [self.eot_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Score (cache_key, context_enc, continuation_enc) triples via echoed logprobs."""
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                client=self.client,
                model=self.model,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
                seed=self.seed,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def generate_until(self, requests) -> List[str]:
        """Generate a completion for each request and cut it at its stop strings.

        Requests are batched in chunks of up to ``REQ_CHUNK_SIZE`` that share
        identical generation kwargs.
        """
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            # Work on a copy so the caller's request kwargs are not mutated.
            request_args = dict(request_args)
            # Per-chunk limit; must not overwrite the instance default.
            max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)

            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - max_gen_toks) :]
                inps.append(inp)

            until = request_args.pop("until", ["<|endoftext|>"])
            if isinstance(until, str):
                # A bare string is one stop sequence, not a list of characters.
                until = [until]
            request_args.pop("do_sample", None)
            request_args["temperature"] = request_args.get("temperature", 0)

            response = oa_completion(
                client=self.client,
                model=self.model,
                prompt=inps,
                max_tokens=max_gen_toks,
                stop=until,
                seed=self.seed,
                **request_args,
            )
            for resp, (context, args_) in zip(response.choices, chunk):
                s = getattr(resp, "text")

                until_ = until

                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "generate_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests) -> List[float]:
        """Return the summed log-likelihood of each full string, scored in rolling windows."""
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods


367
def oa_chat_completion(client, **kwargs):
    """Send a chat-completion request through ``client``, retrying on API errors.

    :param client:
        Client whose ``chat.completions.create`` method is invoked.
    :param kwargs:
        Forwarded unchanged to ``client.chat.completions.create``.
    :return:
        Whatever ``client.chat.completions.create`` returns.
    :raises Exception:
        If either the ``openai`` or the ``tiktoken`` package cannot be found.

    Failures raising ``openai.OpenAIError`` are retried with no upper bound on
    the number of attempts; each failure's traceback is printed.
    """
    packages_present = all(find_spec(pkg) for pkg in ("openai", "tiktoken"))
    if not packages_present:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
        )
    import openai

    def _log_failure(e: Exception, sleep_time: float) -> None:
        # Show what went wrong before the retry helper tries again.
        import traceback

        traceback.print_exc()

    with_retries = retry_on_specific_exceptions(
        on_exceptions=[openai.OpenAIError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_log_failure,
    )

    def _send_request():
        return client.chat.completions.create(**kwargs)

    return with_retries(_send_request)()
394
395


396
@register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM):
    """Generation-only LM backed by an OpenAI-style chat completions endpoint.

    Each request context is sent as a single user message. Log-likelihood
    scoring is not supported.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
        base_url: str = None,
        truncate: bool = False,
        **kwargs,
    ) -> None:
        """
        :param model: str
            Chat model name passed to the endpoint (e.g. gpt-3.5-turbo), or
            the model name served by a local OpenAI-compatible server.
        :param base_url: str
            Address of an OpenAI-compatible server. When unset, the default
            OpenAI client endpoint is used.
        :param truncate: bool
            Stored as ``self.truncate``; not read by any method of this class.
        :param kwargs:
            Accepted and ignored.
        """
        super().__init__()
        try:
            import openai  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        self.base_url = base_url
        self.truncate = truncate

        # Read from environment variable OPENAI_API_KEY
        # Set to EMPTY for local
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()  # openai.AsyncOpenAI()

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def generate_until(self, requests) -> List[str]:
        """Generate one chat reply per request and cut it at the ``until`` strings.

        :param requests:
            Instances whose ``args`` are ``(context, gen_kwargs)``.
            ``gen_kwargs`` must be a dict.
        :return:
            Generated strings, in the original request order.
        :raises ValueError:
            If ``gen_kwargs`` is not a dict, or its ``until`` is neither str
            nor list.
        """
        res = defaultdict(list)
        re_ords = {}

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(
                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
            )

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        for key, re_ord in re_ords.items():
            # n needs to be 1 because messages in
            # chat completion are not batch but
            # is regarded as a single conversation.
            chunks = utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                gen_kwargs = all_gen_kwargs[0]
                until = None
                # Deep-copy so popping keys never mutates the request's kwargs.
                if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
                    kwargs.pop("do_sample", None)
                    if "until" in kwargs.keys():
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            # A bare string is a single stop sequence.
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
                            )
                        kwargs["stop"] = until
                    kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
                else:
                    raise ValueError(
                        f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
                    )

                response = oa_chat_completion(
                    client=self.client, messages=inps, model=self.model, **kwargs
                )

                for resp, (context, args_) in zip(response.choices, chunk):
                    s = resp.message.content

                    if until is not None:
                        for term in until:
                            if len(term) > 0:
                                s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "generate_until", (context, {"until": until}), s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")