import os
import numpy as np
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from the OpenAI API response.

    :param response: dict
        One element of response["choices"] from the OpenAI API
    :param ctxlen: int
        Length of the context (so we can slice it away and keep only the continuation)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            Whether the argmax at every continuation position matches the given token
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break
    
    return continuation_logprobs, is_greedy
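
# Illustrative sketch (values invented, not from the original file): the `response`
# passed to get_result above is one element of response["choices"] from a
# Completion call made with echo=True and logprobs enabled, roughly:
#
#   {
#       "text": " Paris",
#       "logprobs": {
#           "tokens": ["Q", ":", " capital", " of", " France", "?", " Paris"],
#           "token_logprobs": [None, -0.1, -2.3, -0.5, -1.1, -0.4, -0.2],
#           "top_logprobs": [None, {":": -0.1}, ..., {" Paris": -0.2, " Lyon": -2.9}],
#       },
#   }
#
# With ctxlen=6, only the final " Paris" position contributes to
# continuation_logprobs and to the is_greedy check.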


def oa_completion(**kwargs):
    """ Query OpenAI API for completion.

    Retry with exponential back-off until the API responds.
    """
    import openai
    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            time.sleep(backoff_time)
            backoff_time *= 1.5
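

# Hypothetical usage sketch (engine name and prompt are illustrative, not from the
# original file): because of the retry loop above, callers can treat transient
# OpenAIError failures as if they never happened, e.g.
#
#   resp = oa_completion(engine="davinci", prompt="2 + 2 =", max_tokens=1,
#                        temperature=0.0, logprobs=5)
#   print(resp["choices"][0]["text"])
#
# Each failed attempt waits 3s, then 4.5s, then 6.75s, and so on before retrying.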


class GPT3LM(BaseLM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API accepts up to 2049 tokens per request, but no logprob is returned for the very first token, so the usable length is 2048
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
    
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def get_token_logprobs(self, input_tokens, pred_tokens):
        """Score `pred_tokens` with the API, given that `input_tokens` already ends
        with all but the last of them (verified by the assert below)."""
        pred_start = len(input_tokens) - len(pred_tokens) + 1
        # We're going to stitch together the input_tokens and pred_tokens
        # In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
        assert input_tokens[pred_start:] == pred_tokens[:-1]
        token_ids = input_tokens + [pred_tokens[-1]]
        response = oa_completion(
            engine=self.engine,
            prompt=token_ids,
            max_tokens=0,
            temperature=0.0,
            logprobs=0,
            echo=True,
        )
        logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
        positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
        return {
            "logprobs": logprobs,
            "positions": positions,
        }
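
    # Worked example (invented token values, not from the original file): with
    # input_tokens = [a, b, c, d] and pred_tokens = [c, d, e], pred_start = 2;
    # the assert checks input_tokens[2:] == pred_tokens[:-1] == [c, d], the API is
    # queried with token_ids = [a, b, c, d, e] and echo=True, and the returned
    # logprobs[pred_start:] cover positions 2, 3 and 4, i.e. exactly the scores
    # of pred_tokens.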

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
        
        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length+1):]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
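                # Illustrative numbers (not from the original file): if
                # len(context_enc) == 2000 and len(continuation_enc) == 100 with
                # max_length + 1 == 2049, the joined sequence keeps only its last
                # 2049 tokens, dropping 51 context tokens from the front, so
                # ctxlen = 2000 - 51 = 1949.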

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0, temperature=0.,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]
        
        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)
            
            if ret:
                yield ret, lastuntil
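
        # Illustrative example (made-up requests, not from the original file): with
        # size=2 and xs = [(c1, ["\n"]), (c2, ["\n"]), (c3, ["\n"]), (c4, ["###"])],
        # sameuntil_chunks yields ([(c1, ...), (c2, ...)], ["\n"]), then
        # ([(c3, ...)], ["\n"]), then ([(c4, ...)], ["###"]): requests are only
        # batched together while they share the same stop sequence.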

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks, 
                temperature=0.,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp['text']

                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)
                
                res.append(s)
        
        return reord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
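

# Minimal usage sketch (assumptions, not asserted by the original file:
# OPENAI_API_SECRET_KEY is set, an "ada"-style completion engine is available, and
# greedy_until requests are (context, until) pairs as used above).
if __name__ == "__main__":
    lm = GPT3LM(engine="ada")  # engine name is illustrative
    # Greedy completion that stops at the first newline.
    print(lm.greedy_until([("Q: What is the capital of France?\nA:", ["\n"])]))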