"tests/blenderbot/test_modeling_tf_blenderbot.py" did not exist on "543d0549f8337606d723051d9a349f2324b1b559"
gpt3.py 6.93 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
import numpy as np
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of log probabilities of continuation tokens
        is_greedy: bool
            Whether greedy argmax decoding matches the given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break
    
    return continuation_logprobs, is_greedy
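
# A minimal sketch (not executed) of get_result on a toy response, assuming
# ctxlen=2 and the legacy Completions response layout:
#   response = {"logprobs": {
#       "tokens": [" a", " b", " c"],
#       "token_logprobs": [None, -0.5, -1.2],
#       "top_logprobs": [None, {" b": -0.5}, {" d": -0.1, " c": -1.2}]}}
# The context slice (including the leading None logprob) is discarded, so
# continuation_logprobs == -1.2, and is_greedy is False because the top-ranked
# token at the final position (" d") differs from the actual token (" c").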


def oa_completion(**kwargs):
    """ Query OpenAI API for completion.

    Retry with back-off until they respond.
    """
    import openai
    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            time.sleep(backoff_time)
            backoff_time *= 1.5
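
# Usage sketch (assumes openai.api_key is already configured):
#   response = oa_completion(engine="davinci", prompt=["Hello"], echo=True,
#                            max_tokens=0, logprobs=10)
# On repeated OpenAIError failures the wait grows geometrically:
# 3s, 4.5s, 6.75s, ... (factor 1.5), retrying indefinitely.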


class GPT3LM(BaseLM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
    
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
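            # e.g. a 7-token (ctx + cont) request sorts before a 5-token one,
            # since (-7, ...) < (-5, ...): longest requests go first, and the
            # token tuple tie-break keeps identical requests adjacent.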
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
        
        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length+1):]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
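                # e.g. if len(context_enc) == 2000 and len(continuation_enc) == 100,
                # inp keeps the last 2049 tokens and ctxlen == 2000 - 51 == 1949,
                # so get_result still sees exactly the 100 continuation tokens.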

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
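                # echo=True with max_tokens=0 makes the API generate nothing and
                # instead return logprobs for the prompt tokens themselves,
                # which is all this method needs for scoring.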
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0, temperature=0.,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]
        
        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)
            
            if ret:
                yield ret, lastuntil
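
        # e.g. with size=3, [(a, "\n"), (b, "\n"), (c, "."), (d, ".")] yields
        # ([(a, "\n"), (b, "\n")], "\n") then ([(c, "."), (d, ".")], "."), so
        # every request in a chunk shares one stop sequence for the API call.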

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks):]
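                # e.g. with max_length=2048 and max_gen_toks=256 this keeps only
                # the last 1792 context tokens, leaving room for generation.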
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks, 
                temperature=0.,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp['text']

                for term in until_:
                    s = s.split(term)[0]
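                    # e.g. term == "\n" trims "Paris\nQ:" down to "Paris"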

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)
                
                res.append(s)
        
        return reord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
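
# Hypothetical usage sketch (request shape follows lm_eval's BaseLM interface;
# requires OPENAI_API_SECRET_KEY in the environment):
#   lm = GPT3LM(engine="davinci")
#   lm.greedy_until([("Q: 2+2=\nA:", ["\n"])])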