"megatron/2" did not exist on "5d29769cc044a7e4bc52c230321f6c59d1781cca"
gpt3.py 7.15 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
import time

import transformers
from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM, register_model


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

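    # check whether each continuation token was also the top-1 (argmax) token
    # at its position, i.e. whether greedy decoding would reproduce it exactly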
    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with exponential back-off until the API responds.
    """
    import openai

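    # retry forever, growing the sleep interval by 1.5x after each failure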
    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("openai")
class GPT3LM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate the input if it is too long (if False, raise an error on over-long input)
        """
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # to silence the "Using pad_token, but it is not set yet." warning from the tokenizer
        self.tokenizer.pad_token = "<|endoftext|>"
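        # sanity check: this tokenizer must match the GPT-2 BPE used by the API,
        # in particular "\n\n" must encode as two separate newline tokens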
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API accepts up to 2049 tokens per request, but the first
        # token of the prompt never receives a logprob, so the usable length is 2048
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't handle last-token differences efficiently yet; that case
            # is awkward because the top logprobs the API returns aren't guaranteed
            # to contain every continuation we care about, so we'd need some kind
            # of fallback for when they don't
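            # sort key: descending total length first, then the token ids
            # themselves, so similar requests land in the same API chunk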
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
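                # shrink ctxlen by however many context tokens were truncated
                # off the left, so it still indexes correctly into `inp`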
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

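            # echo=True with max_tokens=0 makes the API return logprobs for the
            # prompt itself without generating any new tokens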
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

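        # group consecutive requests that share the same stop sequence, with at
        # most `size` requests per group, since each API call takes one `stop`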
        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # TODO: more intelligent batching for heterogeneous `until` values
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
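                # reserve max_gen_toks of the window for generation and keep only
                # the most recent context tokens that still fit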
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]

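                # trim the completion at the first occurrence of any stop
                # sequence, in case the returned text still contains one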
                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
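

# Example usage (hypothetical; assumes OPENAI_API_SECRET_KEY is set in the
# environment and the legacy `davinci` completion engine is available):
#
#   lm = GPT3LM(engine="davinci")
#   print(lm.greedy_until([("The capital of France is", ["\n"])]))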