import os
import transformers
from lm_eval.api.model import LM, register_model
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: np.array
            Log probabilities of continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with exponential back-off until the API responds.
    """
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5
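
# Usage sketch (hypothetical prompt; "davinci" as in the GPT3LM docstring): a
# single blocking call that retries transient OpenAI errors with back-off:
#
#   resp = oa_completion(engine="davinci", prompt="2 + 2 =", max_tokens=1,
#                        temperature=0.0, logprobs=10)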


@register_model("openai")
class GPT3LM(LM):
    REQ_CHUNK_SIZE = 20
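
    # Usage sketch (assumes OPENAI_API_SECRET_KEY is set in the environment):
    #   lm = GPT3LM(engine="davinci", truncate=False)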

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False, raise an error on over-long input)
        """
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # silence the "Using pad_token, but it is not set yet." warning
        self.tokenizer.pad_token = "<|endoftext|>"
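        # sanity check: the local tokenizer must reproduce this reference
        # encoding (note "\n\n" as two separate tokens), otherwise token counts
        # and ctxlen offsets would drift out of sync with the API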
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API accepts up to 2049 tokens per request, but returns
        # no logprob for the first token, so only 2048 positions are scoreable
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
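            # note: the negative length sorts the longest requests first, so
            # the progress bar's time estimate tends to be conservative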
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )
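                # e.g. (hypothetical sizes) 2,000 context + 100 continuation
                # tokens = 2,100 total; the last 2,049 are kept, 51 context
                # tokens are dropped, and ctxlen = 2000 - 51 = 1949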

                inps.append(inp)
                ctxlens.append(ctxlen)

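            # scoring trick: echo=True with max_tokens=0 makes the API return
            # logprobs for the prompt tokens without generating anything new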
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil
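
        # e.g. with size=2 and per-request untils ["\n"], ["\n"], ["\n"], ["."],
        # this yields ([r1, r2], ["\n"]), then ([r3], ["\n"]), then ([r4], ["."])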

        # TODO: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
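                # keep only the last max_length - max_gen_toks context tokens,
                # leaving room for the generated continuation in the window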
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]
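                # the API already honors stop=until, but re-truncate locally in
                # case a stop string still appears in the returned text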

                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()