import os
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
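
# A sketch of the response shape get_result expects: one OpenAI Completion
# choice produced with echo=True and logprobs set. The token values below are
# invented for illustration; with echo=True the first prompt token has no
# logprob, so the API reports None in position 0.
#
#   resp = {"logprobs": {
#       "tokens": ["Hello", ",", " world"],
#       "token_logprobs": [None, -1.2, -0.8],
#       "top_logprobs": [None, {",": -1.2}, {" world": -0.8}],
#   }}
#   get_result(resp, ctxlen=1)  # -> (-2.0, True)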


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with exponential back-off until the API responds.
    """
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
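            # log the failure, then sleep and retry; the wait grows
            # geometrically (3s, 4.5s, 6.75s, ...) with no retry limit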
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


class GPT3LM(BaseLM):
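    # number of prompts packed into a single OpenAI API request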
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False, raise an error when the input is too long)
        """
        super().__init__()

        import openai

        self.engine = engine
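        # GPT-3 engines share GPT-2's byte-pair encoding, so the local GPT-2
        # tokenizer can count and slice tokens without extra API calls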
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
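        # sanity check that the tokenizer produces the expected BPE ids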
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
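            # negative length sorts longest-first, grouping identical
            # (context, continuation) token sequences next to each other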
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )
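                # e.g. with max_length=2048, a 2000-token context and a
                # 100-token continuation overflow 2049 by 51 tokens; those
                # are trimmed off the context, so ctxlen = 2000 - 51 = 1949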

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
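                # echo=True returns logprobs for the prompt tokens themselves;
                # with max_tokens=0 nothing is generated, so this is pure scoring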
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
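            # greedily pack consecutive requests that share the same stop
            # sequence, since each batched API request takes a single `stop`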
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
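                # reserve max_gen_toks of the context window for generation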
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]

                for term in until_:
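                    # defensively cut the completion at the first stop
                    # sequence, in case the API response still contains one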
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
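

if __name__ == "__main__":
    # Minimal, hypothetical usage sketch: assumes OPENAI_API_SECRET_KEY is
    # set in the environment and that the `davinci` engine is available.
    lm = GPT3LM(engine="davinci")
    # greedy_until takes (context, until) pairs; `until` is a list of stop
    # strings at which generation is cut off
    print(lm.greedy_until([("The capital of France is", ["\n"])]))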