openai_completions.py 6.59 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
Baber Abbasi's avatar
Baber Abbasi committed
2
3
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple, Union
4

5
from lm_eval.api.registry import register_model
Baber Abbasi's avatar
Baber Abbasi committed
6
from lm_eval.models.api_models import TemplateAPI
7
from lm_eval.utils import eval_logger
Leo Gao's avatar
Leo Gao committed
8

lintangsutawika's avatar
update  
lintangsutawika committed
9

Baber Abbasi's avatar
Baber Abbasi committed
10
11
@register_model("local-completions")
class LocalCompletionsAPI(TemplateAPI):
    """Backend registered as ``local-completions``.

    Builds request payloads keyed by ``prompt`` and parses responses whose
    results live under ``choices`` (``text`` for generations, ``logprobs``
    for scoring).
    """

    def __init__(
        self,
        base_url=None,
        tokenizer_backend="huggingface",
        **kwargs,
    ):
        """Forward the endpoint URL, tokenizer backend and extras to the base class."""
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> dict:
        """Build the request body for a generation or a scoring call.

        Args:
            messages: Prompt(s) placed under the ``prompt`` key as-is.
            generate: When true build a generation request, otherwise a
                scoring request (one new token, ``logprobs`` and ``echo`` on).
            gen_kwargs: Generation options. ``do_sample`` is dropped,
                ``max_gen_toks`` becomes ``max_tokens``, ``until`` becomes
                ``stop``; anything left over is forwarded unchanged. May be
                ``None``. The caller's dict is not modified.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The payload dictionary.
        """
        if not generate:
            return {
                "model": self.model,
                "prompt": messages,
                "max_tokens": 1,
                "logprobs": 1,
                "echo": True,
            }

        # Copy first: the keys are popped below and the caller may reuse the
        # same dict for later requests. Also tolerates the ``None`` default.
        options = dict(gen_kwargs) if gen_kwargs else {}
        options.pop("do_sample", False)
        max_tokens = options.pop("max_gen_toks", self._max_gen_toks)
        temperature = options.pop("temperature", 0)
        stop = options.pop("until", ["<|endoftext|>"])
        return {
            "prompt": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stop": stop,
            **options,
        }

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        """Extract ``(continuation log-likelihood, is_greedy)`` per choice.

        Args:
            outputs: One response dict or a list of them; each must carry
                ``choices[*]["logprobs"]`` with ``token_logprobs`` and
                ``top_logprobs`` sequences.
            tokens: Unused; kept for interface compatibility.
            ctxlens: Context length for each choice, paired positionally with
                ``out["choices"]`` of every response.
            **kwargs: Unused.

        Returns:
            One ``(sum of log-probabilities, is_greedy)`` tuple per choice.
            ``is_greedy`` is true when every scored token's log-probability
            equals the highest value among that position's top candidates.
        """
        if not isinstance(outputs, list):
            outputs = [outputs]

        results = []
        for out in outputs:
            for choice, ctxlen in zip(out["choices"], ctxlens):
                assert ctxlen > 0, "Context length must be greater than 0"
                logprob_info = choice["logprobs"]
                # Skip the context positions and the final entry; the scoring
                # payload asks for one extra token, presumably that last entry.
                continuation_logprobs = logprob_info["token_logprobs"][ctxlen:-1]
                top_logprobs = logprob_info["top_logprobs"][ctxlen:-1]
                # Compare log-probability with log-probability. Comparing it
                # with the best candidate's *key* (a token string) can never
                # match and would always report a non-greedy continuation.
                is_greedy = all(
                    token_logprob == max(top.values())
                    for token_logprob, top in zip(continuation_logprobs, top_logprobs)
                )
                results.append((sum(continuation_logprobs), is_greedy))
        return results

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        """Collect the ``text`` of every choice across one or more responses."""
        if not isinstance(outputs, list):
            outputs = [outputs]
        return [choice["text"] for out in outputs for choice in out["choices"]]

    @property
    def api_key(self):
        """Value of ``OPENAI_API_KEY``, or an empty string when it is unset."""
        return os.environ.get("OPENAI_API_KEY", "")
lintangsutawika's avatar
lintangsutawika committed
88
89


Baber Abbasi's avatar
Baber Abbasi committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
@register_model("local-chat-completions")
class LocalChatCompletion(LocalCompletionsAPI):
    """Backend registered as ``local-chat-completions``.

    Sends ``messages`` payloads, reads replies from
    ``choices[*]["message"]["content"]``, and does not support
    log-likelihood scoring.
    """

    def __init__(
        self,
        base_url=None,
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        """Initialise the parent backend and force the batch size down to 1."""
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )
        if self._batch_size > 1:
            eval_logger.warning(
                "Chat completions does not support batching. Defaulting to batch size 1."
            )
            self._batch_size = 1

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> dict:
        """Build the request body for a chat generation call.

        Args:
            messages: Chat messages placed under the ``messages`` key as-is.
            generate: Ignored; the same payload is built either way.
            gen_kwargs: Generation options. ``do_sample`` is dropped,
                ``max_gen_toks`` becomes ``max_tokens``, ``until`` becomes
                ``stop``; anything left over is forwarded unchanged. May be
                ``None``. The caller's dict is not modified.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The payload dictionary.
        """
        # Copy first: the keys are popped below and the caller may reuse the
        # same dict for later requests. Also tolerates the ``None`` default.
        options = dict(gen_kwargs) if gen_kwargs else {}
        options.pop("do_sample", False)
        max_tokens = options.pop("max_gen_toks", self._max_gen_toks)
        temperature = options.pop("temperature", 0)
        stop = options.pop("until", ["<|endoftext|>"])
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        return {
            "messages": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            # Only the first four stop sequences are sent.
            "stop": stop[:4],
            **options,
        }

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        """Collect the message content of every choice across the responses."""
        if not isinstance(outputs, list):
            outputs = [outputs]
        return [
            choice["message"]["content"] for out in outputs for choice in out["choices"]
        ]

    def tok_encode(
        self,
        string: Union[str, Any],
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> Union[List[str], List[int], Any]:
        """Return ``string`` unchanged; no tokenization is performed here."""
        return string

    def _loglikelihood_tokens(self, requests, **kwargs):
        """Always raise: scoring is unavailable for chat completions."""
        raise NotImplementedError(
            "Loglikelihood is not supported for chat completions. Consider using the completions API instead."
        )
lintangsutawika's avatar
lintangsutawika committed
152
153


Baber Abbasi's avatar
Baber Abbasi committed
154
155
156
157
@register_model("openai-completions")
class OpenAICompletionsAPI(LocalCompletionsAPI):
    """Backend registered as ``openai-completions``.

    Same behaviour as the parent completions backend, with the OpenAI
    completions URL and the ``tiktoken`` tokenizer backend as defaults and a
    mandatory API key.
    """

    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        """Pass the defaults (or caller overrides) straight to the parent."""
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            **kwargs,
        )

    @cached_property
    def api_key(self):
        """API key read from ``OPENAI_API_KEY``; computed once per instance.

        Raises:
            ValueError: If the environment variable is not set.
        """
        api_key_value = os.getenv("OPENAI_API_KEY")
        if api_key_value is not None:
            return api_key_value
        raise ValueError(
            "API key not found. Please set the OPENAI_API_KEY environment variable."
        )

    def _loglikelihood_tokens(self, requests, **kwargs):
        """Delegate scoring to the parent after rejecting ``gpt-3.5-turbo``."""
        unsupported_msg = "Loglikelihood is not supported for gpt-3.5-turbo"
        assert self.model != "gpt-3.5-turbo", unsupported_msg
        return super()._loglikelihood_tokens(requests, **kwargs)
183
184


Baber Abbasi's avatar
Baber Abbasi committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
@register_model("openai-chatcompletions")
class OpenAIChatCompletion(LocalChatCompletion):
    """Backend registered as ``openai-chatcompletions``.

    Same behaviour as the parent chat backend, with the OpenAI chat
    completions URL as the default endpoint and a mandatory API key.
    """

    def __init__(
        self,
        base_url="https://api.openai.com/v1/chat/completions",
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        """Pass the defaults (or caller overrides) straight to the parent."""
        parent_args = dict(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
        )
        super().__init__(**parent_args, **kwargs)

    @cached_property
    def api_key(self):
        """API key read from ``OPENAI_API_KEY``; computed once per instance.

        Raises:
            ValueError: If the environment variable is not set.
        """
        api_key_value = os.getenv("OPENAI_API_KEY")
        if api_key_value is not None:
            return api_key_value
        raise ValueError(
            "API key not found. Please set the OPENAI_API_KEY environment variable."
        )