openai_completions.py 6.05 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import time
baberabb's avatar
baberabb committed
3
from typing import List, Tuple
Leo Gao's avatar
Leo Gao committed
4
from tqdm import tqdm
lintangsutawika's avatar
lintangsutawika committed
5
from lm_eval import utils
6
7
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Leo Gao's avatar
Leo Gao committed
8

lintangsutawika's avatar
lintangsutawika committed
9
from openai import OpenAI
Leo Gao's avatar
Leo Gao committed
10

lintangsutawika's avatar
lintangsutawika committed
11
# Module-level OpenAI client shared by all requests in this module.
# NOTE(review): presumably picks up credentials (OPENAI_API_KEY) from the
# environment at import time — confirm; import of this module will fail
# if the openai package cannot construct a client.
client = OpenAI()
Leo Gao's avatar
Leo Gao committed
12

13
def oa_chat_completion(**kwargs):
    """Query the OpenAI chat-completions endpoint.

    Retries indefinitely on API errors, printing the traceback and sleeping
    between attempts with an exponentially growing back-off (starts at 3s,
    grows by x1.5 per failure).
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    delay = 3
    while True:
        try:
            response = client.chat.completions.create(**kwargs)
        except openai.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(delay)
            delay *= 1.5
        else:
            return response


38
39
40
41
@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
    """LM backed by the OpenAI chat-completions API.

    Only generation (`generate_until`) is supported: the chat endpoint does
    not expose token logprobs, so the `loglikelihood*` methods raise
    NotImplementedError.
    """

    # Maximum number of requests grouped into a single chunk (per shared
    # generation kwargs).
    REQ_CHUNK_SIZE = 20

    def __init__(
        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
    ) -> None:
        """
        :param model: str
            OpenAI API model (e.g. gpt-3.5-turbo)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        # Fixed sampling parameters (API defaults); `logit_bias` is held but
        # not currently forwarded to the API (see generate_until).
        self.frequency_penalty = 0
        self.logit_bias = None
        self.n = 1
        self.presence_penalty = 0
        self.temperature = 1
        self.top_p = 1
        self.tokenizer = tiktoken.encoding_for_model(self.model)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # The API key itself is read from the environment by the module-level
        # OpenAI client, not here.

    @property
    def eot_token_id(self):
        # tiktoken's end-of-text token id for this model's encoding.
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        # Cap on tokens generated per request (passed as `max_tokens`).
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        """Encode a string to token ids with the tiktoken encoding."""
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to a string with the tiktoken encoding."""
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize (context, continuation) so that the two halves concatenate
        to the tokenization of the whole string.

        Trailing whitespace on the context is moved onto the continuation
        first, since the tokenizer would otherwise merge it differently.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def generate_until(self, requests) -> List[str]:
        """Generate a continuation for each request, cut at its stop strings.

        :param requests: list of Instance objects; each `.args` is a
            (context, gen_kwargs) pair where gen_kwargs may contain "until".
        :return: one generated string per request, in the original order.
        """
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            # Sort key: context token length, then the context text itself.
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Yield runs of consecutive requests sharing identical gen kwargs,
            # splitting any run longer than `size`.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                # NOTE(review): all contexts in the chunk are sent as user
                # turns of ONE conversation; whether the API returns one
                # choice per turn (as the zip below assumes) should be
                # confirmed — one request per context may be required.
                inps.append({"role": "user", "content": context})

            until = request_args.get("until", None)

            response = oa_chat_completion(
                messages=inps,
                model=self.model,
                frequency_penalty=self.frequency_penalty,
                # logit_bias=self.logit_bias,
                max_tokens=self.max_gen_toks,
                n=self.n,
                presence_penalty=self.presence_penalty,
                temperature=self.temperature,
                top_p=self.top_p,
                # stop=until,  # NOTE(review): API caps stop at 4 strings — confirm before enabling
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                # The v1 openai client returns objects, not dicts: the
                # generated text of a chat choice is at `.message.content`
                # (there is no "text" key on chat completions).
                s = resp.message.content

                # Normalize `until` to a list of stop strings; absent or None
                # means "no stop sequences". (A bare string is wrapped so it
                # is not iterated character-by-character below.)
                until_ = args_.get("until") or []
                if isinstance(until_, str):
                    until_ = [until_]

                # Cut the generation at the first occurrence of any stop string.
                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "generate_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def loglikelihood(self, requests):
        """Not supported: the chat-completions API exposes no logprobs."""
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        """Not supported: the chat-completions API exposes no logprobs."""
        raise NotImplementedError("No support for logits.")