import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np


class GPT2LM(LM):
    MAX_GEN_TOKS = 256
    VOCAB_SIZE = 50257
    EOT_TOKEN_ID = 50256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # the pretrained tokenizer for GPT-Neo is currently broken, so hardcode the GPT-2 tokenizer for now
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig doesn't have n_ctx, so fall back to max_position_embeddings
            self.max_length = self.gpt2.config.max_position_embeddings

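        # sanity check: the tokenizer must split "hello\n\nhello" into the expected
        # GPT-2 token ids, i.e. newlines are tokenized the way the code below assumes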
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.EOT_TOKEN_ID]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_perplexity(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        with torch.no_grad():
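            # each request is a 1-tuple holding the full string to score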
            for string, in tqdm(requests):
                encoded = self.tokenizer.encode_plus(string)["input_ids"]

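                # split the token stream into rolling (context, continuation) windows
                # so the whole string can be scored within the model's max_length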
                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                    token_list=encoded,
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )))

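                # prepend a None cache_key so _loglikelihood_tokens skips partial caching
                # for these rolling windows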
                rolling_token_windows = [(None,) + x for x in rolling_token_windows]

                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                string_nll = self._loglikelihood_tokens(rolling_token_windows)
                
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]
                
                string_nll = sum(string_nll)
                loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

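            # sort key for utils.Reorderer: total token length plus the tokens themselves,
            # so similar requests are processed together; get_original() restores the
            # original ordering at the end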
            def _collate(x):
                toks = x[1] + x[2]
                return (len(toks), tuple(toks))

            reord = utils.Reorderer(requests, _collate)
            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length
                
                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
                # cont_toks      4 5 6 7 8 9

                # when too long to fit in context, truncate from the left
                inp = torch.tensor([
                    (context_enc + continuation_enc)[-(self.max_length+1):] 
                ], dtype=torch.long).to(self.device)

                cont_toks = inp[:, -len(continuation_enc):]  # [batch, seq]

                logits = F.log_softmax(self.gpt2(inp[:, :-1])[0][:, -len(continuation_enc):, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab]; clip to the true vocab size to drop any extra (e.g. padding) logit rows

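                # is_greedy: whether the continuation is exactly what greedy decoding
                # would have produced at every position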
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                last_token_slice = logits[:, -1, :].squeeze(0).tolist()  # note: currently unused

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]

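                # answer = (total log-likelihood of the continuation, is_greedy flag)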
                answer = (float(logits.cpu().to(torch.float64).sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

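            # truncate the context from the left, keeping only the last
            # (max_length - MAX_GEN_TOKS) tokens so the generated tokens still fit
            # within the model's context window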
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

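            # the primary stop string must encode to exactly one token; it is passed
            # to generate() as eos_token_id so generation halts there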
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)
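

# Minimal usage sketch (illustrative addition, not part of the harness itself).
# It assumes the `lm_eval.base.LM` base class provides the `cache_hook` used by
# the methods above, and that `lm_eval.utils` supplies `simple_parse_args_string`
# and `Reorderer` as imported at the top of this file.
if __name__ == "__main__":
    lm = GPT2LM.create_from_arg_string("device=cpu,pretrained=gpt2")

    # loglikelihood takes (context, continuation) pairs and returns
    # (summed log-prob of the continuation, is_greedy) tuples
    print(lm.loglikelihood([("The capital of France is", " Paris")]))

    # greedy_until takes (context, until) pairs and returns the generated text
    # truncated at the first occurrence of the stop string
    print(lm.greedy_until([("Question: What is 2 + 2?\nAnswer:", "\n")]))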