openai_completions.py 9.33 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
Baber Abbasi's avatar
Baber Abbasi committed
2
3
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple, Union
4

5
from lm_eval.api.registry import register_model
Baber Abbasi's avatar
Baber Abbasi committed
6
from lm_eval.models.api_models import TemplateAPI
7
from lm_eval.models.utils import handle_stop_sequences
8
from lm_eval.utils import eval_logger
Leo Gao's avatar
Leo Gao committed
9

lintangsutawika's avatar
update  
lintangsutawika committed
10

Baber Abbasi's avatar
Baber Abbasi committed
11
12
@register_model("local-completions")
class LocalCompletionsAPI(TemplateAPI):
    """Client for an OpenAI-compatible ``/completions`` endpoint.

    Builds request payloads for generation and for prompt-loglikelihood
    scoring, and parses the JSON responses returned by the endpoint.
    """

    def __init__(
        self,
        base_url=None,
        tokenizer_backend="huggingface",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        eos=None,
        **kwargs,
    ) -> dict:
        """Build the JSON body for a completions request.

        Args:
            messages: Prompt(s) sent as the ``prompt`` field (text or token ids).
            generate: If True, build a generation request; otherwise build a
                scoring request that echoes the prompt with logprobs.
            gen_kwargs: Generation options (only used when ``generate``).
                ``do_sample`` is dropped; ``max_tokens`` (falling back to
                ``max_gen_toks``, then ``self._max_gen_toks``), ``temperature``
                and ``until`` are mapped to API fields; any remaining keys are
                forwarded verbatim. The caller's dict is not modified.
            seed: Seed forwarded to the API.
            eos: End-of-sequence string passed to ``handle_stop_sequences``.

        Returns:
            The request payload as a dict.
        """
        if generate:
            # Work on a copy: ``None`` (the default) has no ``pop``, and popping
            # from the caller's dict would leak mutations back to it.
            gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}
            gen_kwargs.pop("do_sample", False)
            # Pop both aliases so neither is forwarded as an unknown API field;
            # an explicit ``max_tokens`` wins over ``max_gen_toks``.
            max_gen_toks = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
            max_tokens = gen_kwargs.pop("max_tokens", max_gen_toks)
            temperature = gen_kwargs.pop("temperature", 0)
            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
            return {
                "prompt": messages,
                "model": self.model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stop": stop,
                "seed": seed,
                **gen_kwargs,
            }
        # Scoring request: echo the prompt back with per-token logprobs while
        # generating a single token.
        return {
            "model": self.model,
            "prompt": messages,
            "temperature": 0,
            "max_tokens": 1,
            "logprobs": 1,
            "seed": seed,
            "echo": True,
        }

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: Optional[List[List[int]]] = None,
        ctxlens: Optional[List[int]] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        """Turn echoed-prompt responses into ``(logprob, is_greedy)`` pairs.

        Args:
            outputs: One response dict or a list of them, each with ``choices``
                carrying ``logprobs.token_logprobs`` and
                ``logprobs.top_logprobs``.
            tokens: Unused; accepted for interface compatibility.
            ctxlens: Context length (in tokens) for each choice, in order.
                Must be provided.

        Returns:
            For each choice, the summed logprob of the continuation tokens and
            whether every continuation token matched the top-1 logprob.
        """
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            for choice, ctxlen in zip(out["choices"], ctxlens):
                assert ctxlen > 0, "Context length must be greater than 0"
                # Skip the context tokens and drop the final entry (presumably
                # the single token generated by the scoring request).
                token_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
                top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
                is_greedy = all(
                    tok == max(top.values())
                    for tok, top in zip(token_logprobs, top_logprobs)
                )
                res.append((sum(token_logprobs), is_greedy))
        return res

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        """Return the ``text`` of every choice in one or more responses."""
        if not isinstance(outputs, list):
            outputs = [outputs]
        return [choice["text"] for out in outputs for choice in out["choices"]]

    @property
    def api_key(self):
        """``OPENAI_API_KEY`` from the environment, or ``""`` if it is unset."""
        return os.environ.get("OPENAI_API_KEY", "")
lintangsutawika's avatar
lintangsutawika committed
97
98


Baber Abbasi's avatar
Baber Abbasi committed
99
100
101
102
103
104
105
106
107
@register_model("local-chat-completions")
class LocalChatCompletion(LocalCompletionsAPI):
    """Client for an OpenAI-compatible ``/chat/completions`` endpoint.

    Requests are sent as chat ``messages`` (untokenized by default), the batch
    size is forced to 1, and only generation is supported.
    """

    def __init__(
        self,
        base_url=None,
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        eval_logger.warning(
            "chat-completions endpoint requires the `--apply_chat_template` flag."
        )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )
        if self._batch_size > 1:
            eval_logger.warning(
                "Chat completions does not support batching. Defaulting to batch size 1."
            )
            self._batch_size = 1

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed=1234,
        eos=None,
        **kwargs,
    ) -> dict:
        """Build the JSON body for a chat-completions request.

        Args:
            messages: Chat messages; a plain string is rejected because it means
                the chat template was not applied.
            generate: Ignored; chat payloads are always generation requests.
            gen_kwargs: Generation options. ``do_sample`` is dropped;
                ``max_tokens`` (falling back to ``max_gen_toks``, then
                ``self._max_gen_toks``), ``temperature`` and ``until`` are mapped
                to API fields; remaining keys are forwarded verbatim. The
                caller's dict is not modified.
            seed: Seed forwarded to the API.
            eos: End-of-sequence string passed to ``handle_stop_sequences``.

        Returns:
            The request payload as a dict.
        """
        assert (
            type(messages) is not str
        ), "chat-completions require the --apply_chat_template flag."
        # Copy so the default ``None`` works and the caller's dict is untouched.
        gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}
        gen_kwargs.pop("do_sample", False)
        # Pop both aliases so neither is forwarded as an unknown API field.
        max_gen_toks = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        max_tokens = gen_kwargs.pop("max_tokens", max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        return {
            "messages": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            # The OpenAI chat API accepts at most 4 stop sequences.
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        """Return ``message.content`` of every choice in one or more responses."""
        if not isinstance(outputs, list):
            outputs = [outputs]
        return [
            choice["message"]["content"]
            for out in outputs
            for choice in out["choices"]
        ]

    def tok_encode(
        self,
        string: Union[str, Any],
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> Union[List[str], List[int], Any]:
        """Return the input unchanged; chat requests are sent untokenized."""
        return string

    def loglikelihood(self, requests, **kwargs):
        """Always raise: chat endpoints do not expose prompt logprobs."""
        raise NotImplementedError(
            "Loglikelihood is not supported for chat completions. Consider using the completions API instead."
        )
lintangsutawika's avatar
lintangsutawika committed
177
178


Baber Abbasi's avatar
Baber Abbasi committed
179
180
181
182
@register_model(
    "openai-completions",
)
class OpenAICompletionsAPI(LocalCompletionsAPI):
    """Completions client preconfigured for OpenAI's hosted ``/v1/completions``.

    Uses the ``tiktoken`` tokenizer backend and requires ``OPENAI_API_KEY``.
    """

    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            **kwargs,
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        env_key = os.environ.get("OPENAI_API_KEY", None)
        if env_key is not None:
            return env_key
        raise ValueError(
            "API key not found. Please set the `OPENAI_API_KEY` environment variable."
        )

    def loglikelihood(self, requests, **kwargs):
        """Score requests, but only for models whose API returns prompt logprobs."""
        allowed_models = ["babbage-002", "davinci-002"]
        assert (
            self.model in allowed_models
        ), f"Prompt loglikelihoods are only supported by OpenAI's API for {allowed_models}."
        return super().loglikelihood(requests, **kwargs)

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Return an empty template string regardless of the argument."""
        return ""

216

Baber Abbasi's avatar
Baber Abbasi committed
217
@register_model("openai-chat-completions")
class OpenAIChatCompletion(LocalChatCompletion):
    """Chat-completions client preconfigured for OpenAI's hosted endpoint.

    Sends ``max_completion_tokens`` instead of ``max_tokens``. For models whose
    name contains ``"o1"``, it drops ``stop`` and forces ``temperature=1``.
    """

    def __init__(
        self,
        base_url="https://api.openai.com/v1/chat/completions",
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        # ``or ""`` guards against an explicit ``model=None``.
        if "o1" in (kwargs.get("model") or ""):
            eval_logger.warning(
                "o1 models do not support `stop` and only support temperature=1"
            )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the `OPENAI_API_KEY` environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        """Always raise: OpenAI chat endpoints do not return prompt logprobs."""
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed=1234,
        eos="<|endoftext|>",
        **kwargs,
    ) -> dict:
        """Build the JSON body for an OpenAI chat-completions request.

        Args:
            messages: Chat messages; a plain string is rejected because it means
                the chat template was not applied.
            generate: Ignored; chat payloads are always generation requests.
            gen_kwargs: Generation options. ``do_sample`` is dropped;
                ``max_tokens`` (falling back to ``max_gen_toks``, then
                ``self._max_gen_toks``) becomes ``max_completion_tokens``;
                ``temperature`` and ``until`` (default ``["<|endoftext|>"]``)
                are mapped to API fields; remaining keys are forwarded verbatim.
                The caller's dict is not modified.
            seed: Seed forwarded to the API.
            eos: End-of-sequence string passed to ``handle_stop_sequences``.

        Returns:
            The request payload as a dict.
        """
        assert (
            type(messages) is not str
        ), "chat-completions require the --apply_chat_template flag."
        # Copy so the default ``None`` works and the caller's dict is untouched.
        gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}
        gen_kwargs.pop("do_sample", False)
        # Pop both aliases so neither is forwarded as an unknown API field.
        max_gen_toks = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        max_tokens = gen_kwargs.pop("max_tokens", max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        output = {
            "messages": messages,
            "model": self.model,
            "max_completion_tokens": max_tokens,
            "temperature": temperature,
            # The OpenAI chat API accepts at most 4 stop sequences.
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }
        if "o1" in (self.model or ""):
            # Matches the constructor warning: no ``stop``, temperature fixed at 1.
            output.pop("stop")
            output["temperature"] = 1
        return output