gpt2.py 6.25 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
Jason Phang's avatar
Jason Phang committed
3
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
4
5
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
6
from tqdm import tqdm
Jason Phang's avatar
Jason Phang committed
7
import numpy as np
Jason Phang's avatar
gpt3  
Jason Phang committed
8
9
10


class GPT2LM(LM):
    # Upper bound on the number of tokens generated per request in
    # greedy_until; also reserves room when truncating prompts.
    MAX_GEN_TOKS = 256

    # True GPT-2 vocabulary size; logits are sliced to this width before
    # scoring (presumably because some checkpoints pad the output layer —
    # TODO confirm against the model configs used).
    VOCAB_SIZE = 50257
    # Token id of <|endoftext|>, used as the stand-in context for empty
    # prompts and as the rolling-window prefix token.
    EOT_TOKEN_ID = 50256
Leo Gao's avatar
Leo Gao committed
14

Leo Gao's avatar
Leo Gao committed
15
    def __init__(self, device=None, pretrained='gpt2'):
        """Load a causal LM checkpoint and the GPT-2 tokenizer.

        device: optional torch device string; when falsy, CUDA is used if
            available, otherwise CPU.
        pretrained: model name or path for AutoModelForCausalLM.
        """
        super().__init__()

        # Resolve the target device: honor an explicit request, otherwise
        # prefer CUDA when present.
        if device:
            resolved = torch.device(device)
        else:
            resolved = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = resolved

        model = transformers.AutoModelForCausalLM.from_pretrained(pretrained)
        self.gpt2 = model.to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"

        # GPT-2 configs expose the context window as n_ctx; GPT-Neo configs
        # only provide max_position_embeddings.
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            self.max_length = self.gpt2.config.max_position_embeddings

        # Sanity check: the tokenizer must split whitespace exactly the way
        # GPT-2 evaluation expects.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

Jason Phang's avatar
Jason Phang committed
35
    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build an instance from a "key=value,key=value" argument string."""
        parsed = utils.simple_parse_args_string(arg_string)
        device = parsed.get("device", None)
        pretrained = parsed.get("pretrained", "gpt2")
        return cls(device=device, pretrained=pretrained)
Jason Phang's avatar
Jason Phang committed
39

Leo Gao's avatar
Leo Gao committed
40
    def loglikelihood(self, requests):
        """Tokenize (context, continuation) pairs and delegate scoring.

        Returns a list of (logprob_sum, is_greedy) tuples, one per request,
        produced by _loglikelihood_tokens.
        """
        tokenized = []
        for context, continuation in requests:
            # An empty context is replaced by a lone <|endoftext|> token so
            # the model still has one position to condition on.
            ctx_tokens = (
                [self.EOT_TOKEN_ID]
                if context == ""
                else self.tokenizer.encode(context)
            )
            cont_tokens = self.tokenizer.encode(continuation)
            tokenized.append(((context, continuation), ctx_tokens, cont_tokens))

        return self._loglikelihood_tokens(tokenized)

Jason Phang's avatar
Jason Phang committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
    def loglikelihood_perplexity(self, requests):
        """Compute per-token log-likelihoods over whole strings.

        Each request is a 1-tuple ``(string,)``. The string is tokenized and
        scored in rolling windows of at most ``self.max_length`` tokens so
        arbitrarily long inputs fit in the model's context; the window helper
        prefixes <|endoftext|> so the very first token is also predicted.

        Returns a list (one entry per request) of numpy arrays of per-token
        log-likelihoods (negated NLLs, so values are <= 0).
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        with torch.no_grad():
            for string, in tqdm(requests):
                encoded = self.tokenizer.encode_plus(string)["input_ids"]
                rolling_token_windows = utils.get_rolling_token_windows(
                    token_list=encoded,
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )
                string_nll = []
                for input_tokens, pred_tokens in rolling_token_windows:
                    inp = torch.tensor([input_tokens], dtype=torch.long).to(self.device)
                    labels = torch.tensor([pred_tokens], dtype=torch.long).to(self.device)
                    # Slice to VOCAB_SIZE before normalizing in case the
                    # checkpoint's output layer is wider than the vocabulary.
                    log_probs = F.log_softmax(
                        self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1
                    )  # [batch, seq, vocab]
                    # Only score the last len(pred_tokens) positions; earlier
                    # positions are context carried over from the prior window.
                    scoring = log_probs[:, -len(pred_tokens):].reshape(
                        len(pred_tokens), self.VOCAB_SIZE
                    )
                    # The values are already log-probabilities, so use nll_loss
                    # directly. The previous F.cross_entropy call re-applied
                    # log_softmax internally — a numerical no-op on log-probs
                    # (log_softmax is idempotent) but redundant and misleading.
                    string_nll.append(
                        F.nll_loss(scoring, target=labels.view(-1), reduction="none")
                        .cpu()
                        .numpy()
                    )
                # Negate NLLs to obtain log-likelihoods, one entry per token.
                loglikelihoods.append(-np.concatenate(string_nll))

        return loglikelihoods

Leo Gao's avatar
Leo Gao committed
84
    def _loglikelihood_tokens(self, requests):
        """Score pre-tokenized (context, continuation) requests.

        Each request is ``(cache_key, context_enc, continuation_enc)``; when
        ``cache_key`` is None the result is not written to the partial cache.

        Returns, in the caller's original request order, a list of
        ``(sum_logprob, is_greedy)`` tuples where ``is_greedy`` indicates
        whether the continuation equals the model's greedy decoding.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

            def _collate(x):
                # Sort by total token length (then content) so similarly-sized
                # requests are adjacent; Reorderer restores the input order.
                toks = x[1] + x[2]
                return (len(toks), tuple(toks))

            reord = utils.Reorderer(requests, _collate)
            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
                # when too long to fit in context, truncate from the left
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                # ctxlen shrinks by however many context tokens were truncated
                # away, so inp[:, ctxlen:] is always exactly the continuation.
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]

                # Log-probs at the positions that predict the continuation
                # tokens (shifted left by one relative to the inputs).
                logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # NOTE: removed an unused `last_token_slice` computation here;
                # it was dead code that also forced a GPU->CPU transfer.

                # Gather the log-prob assigned to each actual continuation token.
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
Leo Gao's avatar
Leo Gao committed
120
    
Leo Gao's avatar
Update  
Leo Gao committed
121
    def greedy_until(self, requests):
        """Greedily generate from each context until a stop string is hit.

        Each request is ``(context, until)`` where ``until`` is a stop string
        or a list of stop strings. Returns the generated completions (with
        stop strings stripped) in the original request order.
        """
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

        def _collate(x):
            # Sort by tokenized context length so similar-length prompts are
            # processed together; Reorderer restores the input order at the end.
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            # Normalize a single stop string into a one-element list.
            if isinstance(until, str): until = [until]

            # Keep only the last (max_length - MAX_GEN_TOKS) context tokens so
            # the prompt plus up to MAX_GEN_TOKS generated tokens fit in the
            # model's window (the slice start is negative).
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

            # NOTE(review): assumes the first stop string encodes to exactly
            # one token — the tuple unpacking raises ValueError otherwise.
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                # Stop generation early when the primary stop token appears.
                eos_token_id=primary_until,
                do_sample=False
            )

            # Decode only the newly generated suffix (drop the prompt tokens).
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # Truncate at the first occurrence of any stop string.
            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)