gpt3.py 7.4 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
Jason Phang's avatar
Jason Phang committed
2
import numpy as np
Jason Phang's avatar
gpt3  
Jason Phang committed
3
import transformers
4
from lm_eval.base import LM, TokenizedLM
Jason Phang's avatar
lib  
Jason Phang committed
5
from lm_eval import utils
Leo Gao's avatar
Leo Gao committed
6
from tqdm import tqdm
Leo Gao's avatar
Leo Gao committed
7
import time
Leo Gao's avatar
Leo Gao committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23


def get_result(response, ctxlen):
    """Extract scoring info for one continuation from an OpenAI API response.

    :param response: dict
        A single choice from the Completion API, containing a "logprobs" entry
        with parallel lists "tokens", "token_logprobs" and "top_logprobs".
    :param ctxlen: int
        Number of leading tokens that belong to the context; only tokens from
        this index onward count toward the continuation.
    :return: (float, bool)
        Sum of the continuation token logprobs, and whether every continuation
        token was also the single most likely token (i.e. greedy decoding
        would have produced this exact continuation).
    """
    logprob_info = response["logprobs"]
    continuation_logprobs = sum(logprob_info["token_logprobs"][ctxlen:])

    is_greedy = True
    # Walk the continuation tokens in lockstep with the per-position
    # top-logprob dicts; bail out at the first non-greedy token.
    for actual_token, candidates in zip(
        logprob_info["tokens"][ctxlen:], logprob_info["top_logprobs"][ctxlen:]
    ):
        best_token = max(candidates.keys(), key=candidates.get)
        if best_token != actual_token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
Jason Phang's avatar
gpt3  
Jason Phang committed
24
25


Leo Gao's avatar
Leo Gao committed
26
27
28
29
30
31
32
33
34
35
36
37
def oa_completion(**kwargs):
    """Call the OpenAI Completion API, retrying forever on API errors.

    Retries use exponential backoff (3s, 4.5s, 6.75s, ...) so transient
    rate-limit / server errors don't kill a long evaluation run.
    Keyword arguments are passed straight through to
    ``openai.Completion.create``.
    """
    import openai

    delay = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            # Back off and try again; errors are assumed transient.
            time.sleep(delay)
            delay *= 1.5


38
class GPT3LM(TokenizedLM):
    """Language model that scores and generates text via the OpenAI
    Completion API, using the GPT-2 tokenizer locally for token accounting.
    """

    # Number of requests batched into a single API call.
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        self.vocab_size = self.tokenizer.vocab_size
        self.eot_token_id = self.tokenizer.eos_token_id
        self.max_gen_toks = 256
        self.max_length = 2048

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        # Sanity-check that the tokenizer matches GPT-2/GPT-3 tokenization.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
        self.truncate = truncate
        # Same id as eot_token_id; kept as a separate attribute for
        # backward compatibility with existing callers.
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    def tok_encode(self, string: str):
        """Tokenize a string to GPT-2 token ids (no special tokens added)."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode a list of GPT-2 token ids back to a string."""
        return self.tokenizer.decode(tokens)

    def loglikelihood_rolling(self, requests):
        """Compute full rolling log-likelihood of each string.

        :param requests: list of 1-tuples, each holding a string to score.
        :return: list of float total log-likelihoods, one per request.
        """
        # TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
        loglikelihoods = []
        for string, in tqdm(requests):
            # Use the class's own encoder for consistency; identical to
            # encode_plus(...)["input_ids"] since GPT-2 adds no special tokens.
            encoded = self.tok_encode(string)
            rolling_token_windows = utils.get_rolling_token_windows(
                token_list=encoded,
                prefix_token=self.end_of_text_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )
            string_loglikelihoods = []
            for input_tokens, pred_tokens in rolling_token_windows:
                block_output = self.get_token_logprobs(
                    input_tokens=input_tokens,
                    pred_tokens=pred_tokens,
                )
                string_loglikelihoods.append(block_output["logprobs"])
            string_loglikelihoods = np.concatenate(string_loglikelihoods).sum()
            loglikelihoods.append(string_loglikelihoods)

        return loglikelihoods

    def get_token_logprobs(self, input_tokens, pred_tokens):
        """Score one rolling window via the API (echo mode, 0 generated tokens).

        :param input_tokens: list of int, the window fed to the model.
        :param pred_tokens: list of int, the tokens to be scored; overlaps the
            tail of input_tokens except for its final token.
        :return: dict with "logprobs" (np.ndarray of per-token logprobs) and
            "positions" (np.ndarray of the scored positions in the window).
        """
        pred_start = len(input_tokens) - len(pred_tokens) + 1
        # We're going to stitch together the input_tokens and pred_tokens
        # In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
        assert input_tokens[pred_start:] == pred_tokens[:-1]
        token_ids = input_tokens + [pred_tokens[-1]]
        response = oa_completion(
            engine=self.engine,
            prompt=token_ids,
            max_tokens=0,
            temperature=0.0,
            logprobs=0,
            echo=True,
        )
        logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
        # One position per scored token: pred_start-1 .. len(token_ids)-2
        # (simplified from arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))).
        positions = np.arange(pred_start - 1, len(token_ids) - 1)
        return {
            "logprobs": logprobs,
            "positions": positions,
        }

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score (context, continuation) token pairs in API-sized batches.

        :param requests: list of (cache_key, context_enc, continuation_enc).
        :param disable_tqdm: bool, suppress the progress bar.
        :return: list of (logprob_sum, is_greedy) in the original request order.
        """
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))

        # Sort longest-first so batches have similar lengths; undone at the end.
        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # Keep only the last max_length tokens; ctxlen shrinks by
                # however many context tokens were truncated away.
                # NOTE(review): with echo=True the API can score max_length+1
                # tokens; consider [-(self.max_length + 1):] — confirm against
                # the engine's context window before changing.
                inp = (context_enc + continuation_enc)[-self.max_length:]
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0, temperature=0.,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        """Generate greedily from each context until a stop sequence is hit.

        :param requests: list of (context, until) pairs, where until is a list
            of stop strings.
        :return: list of generated strings in the original request order.
        """
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Group consecutive requests sharing the same `until` list, capped
            # at `size`, since the API takes a single `stop` per call.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogenous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                # Leave room for max_gen_toks generated tokens.
                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until) in zip(response.choices, chunk):
                s = resp['text']

                # The API may not cut exactly at the stop sequence; trim here.
                for term in until:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), s)

                res.append(s)

        return reord.get_original(res)