import os
import transformers
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process a single API response choice for one (context, continuation) request.

    :param response: per-prompt choice dict from the Completion API (echo=True, logprobs set)
    :param ctxlen: number of context tokens; only tokens at index >= ctxlen are scored
    :return: tuple of (continuation_logprobs, is_greedy)
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
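
# Illustrative sketch (not from the original source): the per-choice payload that
# get_result consumes, as returned by the pre-1.0 OpenAI Completion API when called
# with echo=True and logprobs set. Token strings and values below are made up.
#
#   resp["logprobs"] == {
#       "tokens":         ["Hello", ",", " world"],
#       "token_logprobs": [-0.5, -1.2, -0.3],
#       "top_logprobs":   [{"Hello": -0.5}, {",": -1.2}, {" world": -0.3}],
#   }
#
# get_result(resp, ctxlen=1) would then return (-1.2 + -0.3, True), since both
# continuation tokens are also the highest-probability token at their position.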


def oa_completion(**kwargs):
    """Call openai.Completion.create, retrying with exponential backoff on API errors."""
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            time.sleep(backoff_time)
            backoff_time *= 1.5
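
# Hedged usage sketch (not part of the original file): oa_completion forwards its
# keyword arguments straight to openai.Completion.create and retries on transient
# API errors, so a call might look like:
#
#   response = oa_completion(engine="davinci", prompt=["2 + 2 ="],
#                            max_tokens=1, temperature=0., logprobs=10)
#   completion_text = response.choices[0]["text"]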


class GPT3LM(LM):

    MAX_LENGTH = 2048
    REQ_CHUNK_SIZE = 20
    MAX_GEN_TOKS = 256

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
        self.truncate = truncate

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
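
    # Hedged note (added comment, not in the original source): the evaluation harness is
    # assumed to build this class from a model-args string, e.g.
    # GPT3LM.create_from_arg_string("engine=davinci") -> GPT3LM(engine="davinci").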

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))

    def loglikelihood(self, requests):
        import openai
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = self.tokenizer.encode(x[0] + x[1])
            return (len(toks), self.tokenizer.decode(toks))
        
        reord = utils.Reorderer(requests, _collate)
        
        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            ctxlens = []
            for context, continuation in chunk:
                if context == "":
                    # end of text as context
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)
                    
                continuation_enc = self.tokenizer.encode(continuation)
                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
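                # Worked example (illustrative, not in the original source): with
                # MAX_LENGTH=2048, a 2000-token context and a 100-token continuation
                # overflow the window by 52 tokens, so the oldest 52 context tokens are
                # dropped and ctxlen = 2000 - 52 = 1948; indices >= ctxlen in `inp`
                # then cover exactly the continuation tokens.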

                inps.append(inp)
                ctxlens.append(ctxlen)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0, temperature=0.,
                logprobs=10,
            )

            for resp, ctxlen, (context, continuation) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
            
        return reord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        import openai
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)
            
            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tokenizer.encode(context)
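                # keep only the last (MAX_LENGTH - MAX_GEN_TOKS) context tokens so the
                # prompt plus up to MAX_GEN_TOKS generated tokens fits in the model window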
                inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.MAX_GEN_TOKS, 
                temperature=0.,
                logprobs=10,
                stop=until
            )

            for resp, (context, until) in zip(response.choices, chunk):
                s = resp['text']

                for term in until:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), s)
                
                res.append(s)
        
        return reord.get_original(res)
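

# Hedged end-to-end sketch (illustrative only, not part of the original file): running
# this requires the OPENAI_API_SECRET_KEY environment variable and network access, so
# it is left as a comment rather than executable code.
#
#   lm = GPT3LM(engine="davinci")
#   print(lm.loglikelihood([("The capital of France is", " Paris")]))   # [(logprob, is_greedy)]
#   print(lm.greedy_until([("Q: What is 2 + 2?\nA:", ["\n"])]))         # e.g. [" 4"]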