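"""Hugging Face causal LM adapter (GPT-2 by default) for the evaluation harness.

Implements batched log-likelihood scoring and greedy generation behind the `LM` interface.
"""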
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # the pretrained tokenizer for GPT-Neo is broken for now, so just hard-code the GPT-2 tokenizer
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig apparently doesn't have n_ctx; fall back to max_position_embeddings
            self.max_length = self.gpt2.config.max_position_embeddings

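        # sanity check: the tokenizer should match stock GPT-2 BPE, i.e. 'hello\n\nhello'
        # encodes to [hello, \n, \n, hello] rather than merging the newlines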
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

        # multithreading and batching
        gpus = torch.cuda.device_count()
        batch_size_per_gpu = 2  # TODO: adaptive batch size

        # use at least one "device" worth of batch so the CPU-only case doesn't end up with batch_size == 0
        self.batch_size = batch_size_per_gpu * max(gpus, 1)

        if gpus > 1:
            self.gpt2 = nn.DataParallel(self.gpt2)

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [50256]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():

            def _collate(x):
                # the negative sign on len(toks) sorts descending - this has a few advantages:
                # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
                # - when going through the list, the first element of each batch is always the longest, so it determines
                #   the padded context length for that batch. this simplifies the batching logic and, more importantly,
                #   makes automatic adaptive batching much easier to implement
                # - any OOMs will happen right away rather than near the end

                toks = x[1] + x[2]
                return (-len(toks), tuple(toks))
            
            # TODO: automatic (variable) batch size detection for vectorization
            reord = utils.Reorderer(requests, _collate)
            for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
                inps = []
                inplens = []
                ctxlens = []

                padding_length = None
                for _, context_enc, continuation_enc in chunk:
                    # when too long to fit in context, truncate from the left
                    inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
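                    # account for left-truncation: any dropped tokens come out of the context,
                    # so shrink ctxlen so that only continuation tokens are scored below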
                    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
                    inplen, = inp.shape

                    # since in _collate we make sure length is descending, the longest is always the first one.
                    padding_length = padding_length if padding_length is not None else inplen

                    # pad to length
                    inp = torch.cat([
                        inp, # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
                    ], dim=0)

                    inps.append(inp.unsqueeze(0))
                    inplens.append(inplen)
                    ctxlens.append(ctxlen)

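                # log-softmax over the vocab, sliced to the 50257-token GPT-2 vocab in case the
                # model's output layer is padded to a larger size than the tokenizer's vocab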
                multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1)  # [batch, seq, vocab]

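                # for each request: align logits with the continuation tokens (the logit at
                # position i predicts token i + 1), check whether greedy decoding would have
                # reproduced the continuation exactly, and sum the continuation log-probs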
                for (cache_key, _, _), logits, ctxlen, inp, inplen in zip(chunk, multi_logits, ctxlens, inps, inplens):
                    logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab]

                    greedy_tokens = logits.argmax(dim=-1)
                    cont_toks = inp[:, ctxlen:inplen]  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    last_token_slice = logits[:, -1, :].squeeze(0).tolist()  # log-probs at the final position (currently unused here)

                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]

                    answer = (float(logits.sum()), bool(max_equal))

                    # partial caching
                    if cache_key is not None:
                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                    res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` handling: stop strings that are
        # multiple tokens long or that span token boundaries
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

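            # keep only the last (max_length - MAX_GEN_TOKS) context tokens so that the prompt
            # plus up to MAX_GEN_TOKS generated tokens still fits in the context window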
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

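            # use the first stop string as the eos token for generate(); this assumes it
            # encodes to exactly one token (the unpacking below would fail otherwise)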
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)
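
# Minimal usage sketch (illustrative only; assumes a working lm_eval install and
# network access to download the pretrained weights):
#
#     lm = GPT2LM(device="cpu", pretrained="gpt2")
#     # log-likelihood of the continuation " Paris" given the context
#     print(lm.loglikelihood([("The capital of France is", " Paris")]))
#     # greedy generation until the first "\n" (or at most MAX_GEN_TOKS tokens)
#     print(lm.greedy_until([("Q: What is 2 + 2?\nA:", ["\n"])]))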