import os
import time
from typing import List, Tuple
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

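    # the continuation is "greedy" only if, at every continuation position, the
    # highest-logprob token returned by the API matches the token actually given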
    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with back-off until they respond
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


def oa_chat_completion(**kwargs):
    """Query OpenAI API for chat completion.

    Retry with back-off until they respond
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    backoff_time = 3
    while True:
        try:
            return openai.ChatCompletion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(
        self,
        engine: str = "text-davinci-003",
        truncate: bool = False,
        batch_size: int = 1,
    ) -> None:
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.engine = engine
        self.tokenizer = tiktoken.encoding_for_model(self.engine)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
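        # move any trailing whitespace from the context onto the continuation so that
        # tokenization across the context/continuation boundary matches encoding the
        # whole string at once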
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
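                # if the prompt had to be left-truncated above, the part of the context
                # that was cut off no longer counts toward ctxlen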
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
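                # echo=True together with max_tokens=0 asks the API to return logprobs
                # for the prompt tokens themselves without generating any new tokens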
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def generate_until(self, requests) -> List[str]:
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
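            # reorder requests by tokenized prompt length (ties broken by the raw
            # prompt) so that similarly sized prompts end up in the same API batch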
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
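            # yield chunks of at most `size` consecutive requests that all share the
            # same generation args, so one batched API call can use a single set of
            # stop sequences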
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            until = request_args.get("until", ["<|endoftext|>"])

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

                until_ = args_.get("until", ["<|endoftext|>"])

                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "generate_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests) -> List[float]:
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
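            # break the token sequence into rolling windows that each fit within
            # max_length; every window scores a disjoint slice of tokens, so the
            # per-window log-likelihoods can simply be summed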
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods


@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(
        self,
        engine: str = "gpt-3.5-turbo",
        truncate: bool = False,
        batch_size: int = 1,
    ) -> None:
        """

        :param engine: str
            OpenAI API engine (e.g. gpt-3.5-turbo)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.engine = engine
        self.tokenizer = tiktoken.encoding_for_model(self.engine)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Not used: this model does not support loglikelihood requests
        raise NotImplementedError()

    @property
    def device(self):
        # Not used: this model does not support loglikelihood requests
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def generate_until(self, requests) -> List[str]:
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            until = request_args.get("until", ["<|endoftext|>"])

            for context, args_ in chunk:
                # the chat endpoint takes a single conversation per call, so each
                # request in the chunk is sent separately; truncate the prompt from
                # the left to leave room for generation
                context_enc = self.tok_encode(context)
                inp = self.tok_decode(
                    context_enc[-(self.max_length - self.max_gen_toks) :]
                )

                response = oa_chat_completion(
                    model=self.engine,
                    messages=[{"role": "user", "content": inp}],
                    max_tokens=self.max_gen_toks,
                    temperature=0.0,
                    stop=until,
                )

                s = response.choices[0]["message"]["content"]

                until_ = args_.get("until", ["<|endoftext|>"])

                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "generate_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")
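

# A minimal usage sketch (illustrative only, not part of the harness API): it assumes
# OPENAI_API_SECRET_KEY is set in the environment and that `requests` is a list of
# request objects exposing an `.args` tuple, as the methods above expect.
#
#   completions_lm = OpenaiCompletionsLM(engine="text-davinci-003")
#   scores = completions_lm.loglikelihood(requests)  # -> [(logprob, is_greedy), ...]
#
#   chat_lm = OpenaiChatCompletionsLM(engine="gpt-3.5-turbo")
#   generations = chat_lm.generate_until(requests)  # -> [str, ...]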