gpt3.py 10.1 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
Jason Phang's avatar
Jason Phang committed
2
import numpy as np
Jason Phang's avatar
gpt3  
Jason Phang committed
3
import transformers
4
from lm_eval.base import BaseLM
Jason Phang's avatar
lib  
Jason Phang committed
5
from lm_eval import utils
Leo Gao's avatar
Leo Gao committed
6
from tqdm import tqdm
Leo Gao's avatar
Leo Gao committed
7
import time
Leo Gao's avatar
Leo Gao committed
8
9
10


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    The last position of every per-token list is dropped before scoring
    (callers in this file request ``echo=True`` with ``max_tokens=1``, so the
    final entry is the one extra generated token, not part of the input).

    :param response: dict
        One choice of an OpenAI API response; must contain
        ``response["logprobs"]`` with ``tokens``, ``token_logprobs`` and
        ``top_logprobs``
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    logprob_info = response["logprobs"]

    # Slice once up front instead of re-slicing on every loop iteration.
    tokens = logprob_info["tokens"][:-1]
    token_logprobs = logprob_info["token_logprobs"][:-1]
    top_logprobs = logprob_info["top_logprobs"][:-1]

    continuation_logprobs = sum(token_logprobs[ctxlen:])

    is_greedy = True
    for token, top_tokens in zip(tokens[ctxlen:], top_logprobs[ctxlen:]):
        # The greedy choice is the candidate with the highest log probability.
        top_token = max(top_tokens, key=top_tokens.get)
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
Jason Phang's avatar
gpt3  
Jason Phang committed
37
38


Leo Gao's avatar
Leo Gao committed
39
40
41
class _goose:
    """Bare container that mimics a completion response by exposing ``choices``.

    ``oa_completion`` instantiates it with no arguments and assigns ``choices``
    afterwards when it merges several single-prompt responses into one object.
    """
    # One entry per prompt: the first choice of each single-prompt response.
    choices: list

Leo Gao's avatar
Leo Gao committed
42
def oa_completion(**kwargs):
    """ Query OpenAI API for completion.

    Retry with back-off until they respond

    When ``prompt`` is a list with more than one entry, each entry is sent as
    its own single-prompt request through dask, and the first choice of every
    response is gathered, in prompt order, into one ``_goose`` object.

    :param kwargs:
        Keyword arguments forwarded unchanged to ``openai.Completion.create``;
        ``prompt`` is required
    :return:
        The ``openai.Completion.create`` response, or, for a multi-prompt
        list, a ``_goose`` whose ``choices`` holds one choice per prompt
    """
    import openai

    prompt = kwargs["prompt"]
    if isinstance(prompt, list) and len(prompt) > 1:
        import dask

        pending = []
        for single_prompt in prompt:
            single_kwargs = kwargs.copy()
            single_kwargs["prompt"] = [single_prompt]
            pending.append(dask.delayed(oa_completion)(**single_kwargs))
        responses = dask.compute(*pending)

        merged = _goose()
        merged.choices = [r.choices[0] for r in responses]
        # Return here: falling through would discard the merged result and
        # send the whole batch to the API a second time.
        return merged

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()

            # print_exc(file=...) needs a writable file object, not a path.
            # Skip the file copy when QUESTION_RESULT_PATH is not configured,
            # so a logging problem cannot break the retry loop.
            result_dir = os.environ.get("QUESTION_RESULT_PATH")
            if result_dir is not None:
                with open(os.path.join(result_dir, "traceback.txt"), "a") as log_file:
                    traceback.print_exc(file=log_file)

            time.sleep(backoff_time)
            backoff_time *= 1.5


Leo Gao's avatar
Leo Gao committed
74
class GPT3LM(BaseLM):
    """Language model that scores and generates text through ``oa_completion``.

    Tokenization is done locally with the GPT-2 tokenizer; log-likelihoods and
    greedy generations come from the remote completion API.
    """

    # Maximum number of requests sent to the API in a single completion call.
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False, api_key=None, pass_strings=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error).
            Stored as ``self.truncate``; not read anywhere else in this class.
        :param api_key: str or None
            API key; when falsy, read from environment variable OPENAI_API_SECRET_KEY
        :param pass_strings: bool
            Send prompts to the API as decoded strings rather than token ids.
            Must be truthy: the assertion below rejects anything else.
        """
        super().__init__()

        assert pass_strings, "so far, this branch only supports GooseAI, and won't work with the regular OpenAI api. this is mostly because there are still some remaining differences between the two apis that make this more complicated than just a drop in replacement. there's no fundamental reason why I couldn't support both on the same branch right now, but it would be a lot of work, and once gooseai finally makes their api conform to the openai api then we won't need this branch anymore and I'll implement something more simple once that does actually happen."

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.pass_strings = pass_strings

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        # Sanity check that the loaded tokenizer produces the expected GPT-2 ids.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = api_key or os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        """Encode ``string`` into token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode a sequence of token ids back into a string."""
        return self.tokenizer.decode(tokens)

    def _request_completion(self, **kwargs):
        """Call ``oa_completion`` for this engine until it returns.

        Any exception is printed and the call is retried after a one second
        pause, so this only returns once a response was obtained.

        :param kwargs:
            Extra keyword arguments forwarded to ``oa_completion``
        :return:
            The object returned by ``oa_completion``
        """
        while True:
            try:
                return oa_completion(engine=self.engine, **kwargs)
            except Exception as e:
                print(e)
                print("pausing")
                time.sleep(1)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score continuations given their contexts via the completion API.

        :param requests:
            Iterable of ``(cache_key, context_enc, continuation_enc)`` triples,
            where the two ``*_enc`` entries are lists of token ids
        :param disable_tqdm: bool
            Forwarded to ``tqdm`` as ``disable``
        :return:
            List of ``(continuation_logprobs, is_greedy)`` tuples as produced by
            ``get_result``, passed through ``Reorderer.get_original``
        """
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length+1):]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))

                if self.pass_strings:
                    inp = self.tok_decode(inp)

                inps.append(inp)
                ctxlens.append(ctxlen)

            # echo=True returns logprobs for the prompt itself; max_tokens=1
            # adds a single generated token that get_result slices away.
            response = self._request_completion(
                prompt=inps,
                echo=True,
                max_tokens=1,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        """Generate greedily (temperature 0) for each context until a stop sequence.

        :param requests:
            Sequence of ``(context, until)`` pairs; ``until`` is passed to the
            API as ``stop`` and its entries are also used to trim the output
        :return:
            List of generated strings, passed through ``Reorderer.get_original``;
            an empty list when ``requests`` is empty
        """
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Yield (chunk, until) groups of at most `size` consecutive
            # requests that all share the same `until`.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                # tok_encode only accepts the string; truncation is done by
                # the slice below, which leaves room for max_gen_toks tokens.
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                inps.append(self.tok_decode(inp))

            response = self._request_completion(
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp['text']

                # Cut the generation at the first occurrence of each stop term.
                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return reord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
Leo Gao's avatar
Leo Gao committed
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278


class GooseAILM(GPT3LM):
    """``GPT3LM`` variant that sends its requests to the GooseAI endpoint."""

    def __init__(self, engine, truncate=False, api_key=None, force_pile_tokenizer=False):
        """

        :param engine: str
            GooseAI engine name; "gpt-neo-20b" switches to the pile tokenizer
        :param truncate: bool
            Forwarded to ``GPT3LM.__init__``
        :param api_key: str or None
            API key; when falsy, read from environment variable GOOSEAI_API_SECRET_KEY
        :param force_pile_tokenizer: bool
            Use the pile tokenizer regardless of ``engine``
        """
        super().__init__(engine, truncate=truncate, api_key=api_key or os.environ["GOOSEAI_API_SECRET_KEY"], pass_strings=True)
        import openai
        openai.api_base = "https://api.goose.ai/v1"

        if engine == "gpt-neo-20b" or force_pile_tokenizer:
            # Imported here so best_download is only required when the
            # tokenizer file actually has to be fetched.
            from best_download import download_file

            download_file("http://eaidata.bmk.sh/data/pile_tokenizer.json", expected_checksum="d27f071586925d23ef1c4acdee28fb8bf5d99c4a9d638b4e3b08812e3eae6ee7", local_file="pile_tokenizer.json")
            self.tokenizer = transformers.PreTrainedTokenizerFast(tokenizer_file="pile_tokenizer.json")

            # The parent derived these from the GPT-2 tokenizer; recompute them
            # so they describe the tokenizer that is actually in use.
            self.vocab_size = self.tokenizer.vocab_size
            self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

    @property
    def max_length(self):
        # Note: this is temporary, will be raised to 2048 in the future
        return 2022

    @property
    def eot_token_id(self):
        # NOTE(review): the pile tokenizer above is constructed without an
        # explicit eos_token, so this may be None in that case - confirm.
        return self.tokenizer.eos_token_id

    @property
    def max_gen_toks(self):
        return 64