openai_completions.py 9.61 KB
Newer Older
Baber's avatar
types  
Baber committed
1
2
from __future__ import annotations

Lintang Sutawika's avatar
Lintang Sutawika committed
3
import logging
Jason Phang's avatar
gpt3  
Jason Phang committed
4
import os
Baber Abbasi's avatar
Baber Abbasi committed
5
from functools import cached_property
6
from operator import itemgetter
Baber's avatar
types  
Baber committed
7
from typing import Any
8

9
from lm_eval.api.registry import register_model
Baber Abbasi's avatar
Baber Abbasi committed
10
from lm_eval.models.api_models import TemplateAPI
11
from lm_eval.models.utils import handle_stop_sequences
Lintang Sutawika's avatar
Lintang Sutawika committed
12
13
14


eval_logger = logging.getLogger(__name__)
Leo Gao's avatar
Leo Gao committed
15

lintangsutawika's avatar
update  
lintangsutawika committed
16

Baber Abbasi's avatar
Baber Abbasi committed
17
18
@register_model("local-completions")
class LocalCompletionsAPI(TemplateAPI):
    """Completions-style API client for OpenAI-compatible servers.

    Builds request payloads for a ``/completions`` endpoint and parses
    either per-token logprobs (for loglikelihood scoring) or generated
    text out of the responses.
    """

    def __init__(
        self,
        base_url: str | None = None,
        tokenizer_backend: str = "huggingface",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: list[list[int]] | list[dict] | list[str] | str,
        generate: bool = False,
        gen_kwargs: dict | None = None,
        seed: int = 1234,
        eos: str | None = None,
        **kwargs,
    ) -> dict:
        """Build the JSON payload for one completions request.

        When ``generate`` is True, a sampling request is built from
        ``gen_kwargs`` (consumed via ``pop`` so leftovers pass through
        verbatim). Otherwise a 1-token ``echo`` request is built, which
        returns prompt logprobs for loglikelihood scoring.
        """
        if generate:
            # Tolerate a missing kwargs dict instead of crashing on None.pop().
            gen_kwargs = {} if gen_kwargs is None else gen_kwargs
            # `do_sample` is an HF-ism with no completions-API equivalent.
            gen_kwargs.pop("do_sample", False)
            if "max_tokens" in gen_kwargs:
                max_tokens = gen_kwargs.pop("max_tokens")
            else:
                max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
            temperature = gen_kwargs.pop("temperature", 0)
            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
            return {
                "prompt": messages,
                "model": self.model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stop": stop,
                "seed": seed,
                **gen_kwargs,
            }
        # Loglikelihood request: echo the prompt back with per-token logprobs.
        return {
            "model": self.model,
            "prompt": messages,
            "temperature": 0,
            "max_tokens": 1,
            "logprobs": 1,
            "seed": seed,
            "echo": True,
        }

    @staticmethod
    def parse_logprobs(
        outputs: dict | list[dict],
        tokens: list[list[int]] | None = None,
        ctxlens: list[int] | None = None,
        **kwargs,
    ) -> list[tuple[float, bool]]:
        """Extract ``(continuation_logprob, is_greedy)`` pairs from responses.

        ``ctxlens[i]`` is the context-token count of the i-th request; only
        logprobs after the context (and excluding the final position) count
        toward the continuation score.
        """
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            # Choices may arrive out of order; sort by index so they line
            # up with ctxlens.
            for choice, ctxlen in zip(
                sorted(out["choices"], key=itemgetter("index")), ctxlens
            ):
                assert ctxlen > 0, "Context length must be greater than 0"
                logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
                tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
                top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
                # Greedy iff every continuation token's logprob equals the
                # best top-logprob at that position.
                is_greedy = True
                for tok, top in zip(tokens_logprobs, top_logprobs):
                    if tok != max(top.values()):
                        is_greedy = False
                        break
                res.append((logprobs, is_greedy))
        return res

    @staticmethod
    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
        """Collect generated texts from responses, ordered by choice index."""
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            # Place each choice by its reported index to restore order.
            tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
                tmp[choices["index"]] = choices["text"]
            res = res + tmp
        return res

    @property
    def api_key(self):
        """API key from ``OPENAI_API_KEY``; empty string when unset."""
        return os.environ.get("OPENAI_API_KEY", "")
lintangsutawika's avatar
lintangsutawika committed
106
107


Baber Abbasi's avatar
Baber Abbasi committed
108
109
110
111
@register_model("local-chat-completions")
class LocalChatCompletion(LocalCompletionsAPI):
    """Chat-completions client for OpenAI-compatible servers.

    Requests are sent as untokenized message lists; batching and
    loglikelihood scoring are unsupported by the endpoint.
    """

    def __init__(
        self,
        base_url: str | None = None,
        tokenizer_backend: str | None = None,
        tokenized_requests: bool = False,
        **kwargs,
    ):
        eval_logger.warning(
            "chat-completions endpoint requires the `--apply_chat_template` flag."
        )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )
        if self._batch_size > 1:
            eval_logger.warning(
                "Chat completions does not support batching. Defaulting to batch size 1."
            )
            self._batch_size = 1

    def _create_payload(
        self,
        messages: list[dict],
        generate: bool = False,
        gen_kwargs: dict | None = None,
        seed: int = 1234,
        eos: str | None = None,
        **kwargs,
    ) -> dict:
        """Build the JSON payload for one chat-completions request.

        ``gen_kwargs`` entries are consumed via ``pop``; anything left over
        is forwarded verbatim in the payload.
        """
        assert not isinstance(messages, str), (
            "chat-completions require the --apply_chat_template flag."
        )
        # Tolerate a missing kwargs dict instead of crashing on None.pop().
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        # `do_sample` is an HF-ism with no chat-completions equivalent.
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
        else:
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        return {
            "messages": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            # OpenAI-style APIs accept at most 4 stop sequences.
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }

    @staticmethod
    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
        """Collect message contents from responses, ordered by choice index."""
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            # Place each choice by its reported index to restore order.
            tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
                tmp[choices["index"]] = choices["message"]["content"]
            res = res + tmp
        return res

    def tok_encode(
        self,
        string: str | Any,
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> list[str] | list[int] | Any:
        """No-op: chat requests are sent to the server untokenized."""
        return string

    def loglikelihood(self, requests, **kwargs):
        """Unsupported: chat endpoints do not return prompt logprobs."""
        raise NotImplementedError(
            "Loglikelihood is not supported for chat completions. Consider using the completions API instead."
        )
lintangsutawika's avatar
lintangsutawika committed
188
189


Baber Abbasi's avatar
Baber Abbasi committed
190
191
192
193
@register_model(
    "openai-completions",
)
class OpenAICompletionsAPI(LocalCompletionsAPI):
    """Client for OpenAI's hosted ``/v1/completions`` endpoint.

    Defaults to the tiktoken tokenizer backend and requires the
    ``OPENAI_API_KEY`` environment variable.
    """

    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    @cached_property
    def api_key(self):
        """Return the API key for requests, read once from the environment."""
        if (key := os.environ.get("OPENAI_API_KEY", None)) is None:
            raise ValueError(
                "API key not found. Please set the `OPENAI_API_KEY` environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        """Score loglikelihoods; only the legacy base models support this."""
        supported = ["babbage-002", "davinci-002"]
        assert self.model in supported, (
            f"Prompt loglikelihoods are only supported by OpenAI's API for {supported}."
        )
        return super().loglikelihood(requests, **kwargs)

    def chat_template(self, chat_template: bool | str = False) -> str | None:
        """No chat template applies to the completions endpoint."""
        return ""

226

Baber Abbasi's avatar
Baber Abbasi committed
227
@register_model("openai-chat-completions")
class OpenAIChatCompletion(LocalChatCompletion):
    """Client for OpenAI's hosted ``/v1/chat/completions`` endpoint."""

    def __init__(
        self,
        base_url="https://api.openai.com/v1/chat/completions",
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        if "o1" in kwargs.get("model", ""):
            eval_logger.warning(
                "o1 models do not support `stop` and only support temperature=1"
            )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )

    @cached_property
    def api_key(self):
        """Return the API key for requests, read once from the environment."""
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the `OPENAI_API_KEY` environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        """Unsupported: OpenAI chat endpoints provide no prompt logprobs."""
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )

    def _create_payload(
        self,
        messages: list[dict],
        generate: bool = False,
        gen_kwargs: dict | None = None,
        seed: int = 1234,
        eos: str = "<|endoftext|>",
        **kwargs,
    ) -> dict:
        """Build the JSON payload for one OpenAI chat-completions request.

        ``gen_kwargs`` entries are consumed via ``pop``; leftovers are
        forwarded verbatim. Reasoning-model quirks (no ``stop``, fixed or
        absent ``temperature``) are patched onto the payload at the end.
        """
        assert not isinstance(messages, str), (
            "chat-completions require the --apply_chat_template flag."
        )
        # Tolerate a missing kwargs dict instead of crashing on None.pop().
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        # `do_sample` is an HF-ism with no chat-completions equivalent.
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
        else:
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        output = {
            "messages": messages,
            "model": self.model,
            # Newer OpenAI models expect `max_completion_tokens`, not
            # the legacy `max_tokens` field.
            "max_completion_tokens": max_tokens,
            "temperature": temperature,
            # OpenAI accepts at most 4 stop sequences.
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }
        # NOTE(review): substring matching on the model name ("o1"/"o3")
        # targets OpenAI reasoning models, which reject `stop` and
        # non-default temperature — could also match unrelated names;
        # confirm against the supported model list.
        if "o1" in self.model:
            output.pop("stop")
            output["temperature"] = 1
        elif "o3" in self.model:
            output.pop("temperature")
        return output