import os
import numpy as np
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: np.array
            Log probabilities of continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

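    # the continuation counts as "greedy" iff every continuation token is the
    # argmax of the model's distribution at its position, as reported in
    # top_logprobs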
    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retries with exponential back-off until the API responds.
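
    Example (a sketch, assuming openai.api_key is already set):

        resp = oa_completion(engine="davinci", prompt=["Hello world"],
                             max_tokens=0, echo=True, logprobs=10)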
    """
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


class GPT3LM(BaseLM):
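    # number of requests sent to the API in one batched completion call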
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
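        # sanity-check that the GPT-2 tokenizer encodes text the way the
        # token-slicing logic below assumes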
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

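        # sort by descending total token length so each batch contains
        # similarly sized prompts; Reorderer restores the original order later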
        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
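                # ctxlen counts only the context tokens that survive the
                # left-truncation above, so get_result slices at the right spot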
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

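            # echo=True with max_tokens=0 asks the API to score the prompt
            # tokens without generating any new ones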
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

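        # group consecutive requests that share the same stop condition, so
        # each batched call can pass a single `stop` argument to the API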
        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
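                # leave room in the context window for up to max_gen_toks
                # generated tokens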
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=request_args["until"],
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

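                # pull the stop sequences back out of the request args and
                # truncate client-side as well, in case the API response
                # still contains a stop sequence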
                until_ = args_["until"]
                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
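
# Example usage (a sketch; "davinci" is a legacy engine name, and the request
# format (context, {"until": [...]}) follows greedy_until above):
#   lm = GPT3LM(engine="davinci")
#   outputs = lm.greedy_until([("Q: What is 2+2?\nA:", {"until": ["\n"]})])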