import os
import time
from typing import List, Tuple
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            Whether greedy decoding (argmax at each position) reproduces the continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

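    # The continuation counts as greedy only if every continuation token is the
    # argmax of the model's next-token distribution at its position.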
    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with exponential back-off until the API responds.
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    backoff_time = 3
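    # Retry indefinitely, growing the sleep interval 1.5x after every OpenAIError.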
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
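    # Number of prompts batched into a single API request.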
    REQ_CHUNK_SIZE = 20

    def __init__(
        self,
        engine: str = "text-davinci-003",
        truncate: bool = False,
        batch_size: int = 1,
    ):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.engine = engine
        self.tokenizer = tiktoken.encoding_for_model(self.engine)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.end_of_text_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
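        # Move trailing whitespace from the end of the context onto the start of
        # the continuation, so the context/continuation boundary tokenizes the
        # same way as the concatenated string does.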
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm=False
    ) -> List[Tuple[float, bool]]:
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
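                # Length of context that survives left-truncation: tokens sliced
                # off the front of the prompt reduce the effective context length.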
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def greedy_until(self, requests) -> List[str]:
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

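        # Group consecutive requests that share the same stop-sequence args, since
        # the API accepts a single `stop` value per call; groups are capped at `size`.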
        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # TODO: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            until = request_args.get("until", ["<|endoftext|>"])

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

                until_ = args_.get("until", ["<|endoftext|>"])

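                # Trim the generation at the first occurrence of each stop string,
                # in case the returned text runs past a stop sequence.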
                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "greedy_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests) -> List[float]:
        loglikelihoods = []

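        # Score each string with rolling windows: split its tokens into disjoint
        # windows of up to max_length tokens (the first window prefixed with the
        # EOT token), score each window, and sum the log-likelihoods.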
        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods
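

# Example usage (a minimal sketch, not part of the harness itself; assumes the
# OPENAI_API_SECRET_KEY environment variable is set and that `requests` are
# lm_eval Instance objects whose .args hold (context, continuation) pairs built
# by the evaluation harness):
#
#     lm = OpenaiCompletionsLM(engine="text-davinci-003")
#     scores = lm.loglikelihood(requests)      # -> [(sum_logprob, is_greedy), ...]
#     outputs = lm.greedy_until(gen_requests)  # gen_requests args: (context, {"until": [...]})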