openai_completions.py 7.04 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import time
baberabb's avatar
baberabb committed
3
from typing import List, Tuple
lintangsutawika's avatar
update  
lintangsutawika committed
4
5
6

import copy
from collections import defaultdict
Leo Gao's avatar
Leo Gao committed
7
from tqdm import tqdm
lintangsutawika's avatar
update  
lintangsutawika committed
8

lintangsutawika's avatar
lintangsutawika committed
9
from lm_eval import utils
10
11
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Leo Gao's avatar
Leo Gao committed
12

lintangsutawika's avatar
lintangsutawika committed
13
from openai import OpenAI
Leo Gao's avatar
Leo Gao committed
14

lintangsutawika's avatar
lintangsutawika committed
15
# Module-level OpenAI client; the SDK reads the OPENAI_API_KEY environment
# variable when the client is constructed.
# NOTE(review): instantiating at import time means importing this module can
# fail (or misconfigure) even when this backend is never used — confirm that
# lazy construction isn't preferred here.
client = OpenAI()
Leo Gao's avatar
Leo Gao committed
16

lintangsutawika's avatar
update  
lintangsutawika committed
17

18
def oa_chat_completion(**kwargs):
    """Query OpenAI API for chat completion.

    Retries forever on any ``openai.OpenAIError``, printing the traceback and
    sleeping with exponential back-off between attempts (3s, 4.5s, 6.75s, ...).
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    attempt = 0
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except openai.OpenAIError:
            import traceback

            # Surface the failure, then back off exponentially and retry.
            traceback.print_exc()
            time.sleep(3 * 1.5**attempt)
            attempt += 1


43
44
45
46
@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
    """LM backend for the OpenAI chat-completions API.

    Only free-form generation (``generate_until``) is supported: the chat
    endpoint does not expose per-token logprobs for arbitrary continuations,
    so both loglikelihood methods raise ``NotImplementedError``.
    """

    # Maximum number of requests bundled into one API call batch.
    REQ_CHUNK_SIZE = 20

    def __init__(
        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
    ) -> None:
        """
        :param model: str
            OpenAI API model (e.g. gpt-3.5-turbo)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        :param batch_size: int
            Unused here; accepted for interface compatibility with other backends.
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        # Fixed sampling defaults mirroring the OpenAI API defaults.
        self.frequency_penalty = 0
        self.logit_bias = None
        self.n = 1
        self.presence_penalty = 0
        self.temperature = 1
        self.top_p = 1
        self.tokenizer = tiktoken.encoding_for_model(self.model)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_SECRET_KEY

    @property
    def eot_token_id(self):
        # End-of-text token id from the tiktoken encoding.
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        # Default generation budget when the request doesn't specify one.
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        """Encode *string* to token ids with the model's tiktoken encoding."""
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize (context, continuation) so the pair concatenates cleanly.

        Trailing whitespace is moved from the context onto the continuation
        before encoding, because BPE merges across the boundary otherwise.
        The continuation's tokens are taken from the encoding of the whole
        string so that context + continuation tokens reproduce it exactly.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def generate_until(self, requests) -> List[str]:
        """Generate a completion per request, truncating at stop sequences.

        Requests are grouped by their generation kwargs so that differing
        sampling settings are never mixed in one API batch; within a group
        they are reordered by descending context token length.
        """
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            # Sort key: longest context first, tie-broken by the raw string.
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        for key, re_ord in re_ords.items():
            chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
            for chunk in chunks:
                # NOTE: everything below must run once per chunk — previously
                # it sat outside this loop, so only the final chunk of each
                # group was actually sent to the API.
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                # All requests in a chunk share kwargs by construction.
                gen_kwargs = all_gen_kwargs[0]
                until = None
                if isinstance(gen_kwargs, dict):
                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                    if "until" in kwargs.keys():
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            # Fix: wrap the stop string itself (was `[kwargs]`,
                            # which put the kwargs dict in the stop list).
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                            )
                else:
                    # Fix: report gen_kwargs — `kwargs` is unbound on this path.
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )

                if "max_gen_toks" in kwargs.keys():
                    max_gen_toks = kwargs.pop("max_gen_toks")
                else:
                    max_gen_toks = self.max_gen_toks

                response = oa_chat_completion(
                    messages=inps,
                    model=self.model,
                    frequency_penalty=self.frequency_penalty,
                    # logit_bias=self.logit_bias,
                    max_tokens=max_gen_toks,
                    n=self.n,
                    presence_penalty=self.presence_penalty,
                    temperature=self.temperature,
                    top_p=self.top_p,
                )

                for resp, (context, args_) in zip(response.choices, chunk):
                    s = resp.message.content

                    # Client-side truncation at the first occurrence of any
                    # stop sequence (the API call does not pass `stop`).
                    if until is not None:
                        for term in until:
                            if len(term) > 0:
                                s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "generate_until", (context, {"until": until}), s
                    )
                    pbar.update(1)

            # Undo the length-based reordering within this kwargs group.
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        # Undo the kwargs-based grouping to restore original request order.
        return grouper.get_original(res)

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")