import transformers
import torch
import torch.nn.functional as F

from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    """LM-interface wrapper around a HuggingFace causal LM (GPT-2 family).

    Implements token-level loglikelihood scoring and greedy generation
    on top of ``transformers.AutoModelForCausalLM``.
    """

    # Upper bound on the number of new tokens produced per request in
    # greedy_until; also reserved out of the context window when truncating.
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        """Load model *pretrained* onto *device* (auto-selects CUDA if unset).

        :param device: torch device string (e.g. 'cuda', 'cpu'); falsy means
            auto-detect.
        :param pretrained: HuggingFace model name or path for the causal LM.
        """
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        # Inference only: disable dropout etc.
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"

        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparantly
            self.max_length = self.gpt2.config.max_position_embeddings

        # Sanity check that the tokenizer behaves like the GPT-2 BPE
        # (multi-token whitespace handling) this class assumes.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Construct an instance from a comma-separated ``key=value`` string.

        Recognized keys: ``device`` (default None) and ``pretrained``
        (default 'gpt2'); anything else is ignored.
        """
        parsed = utils.simple_parse_args_string(arg_string)
        device = parsed.get("device", None)
        pretrained = parsed.get("pretrained", "gpt2")
        return cls(device=device, pretrained=pretrained)

    def loglikelihood(self, requests):
        """Tokenize (context, continuation) string pairs and score them.

        An empty context is replaced by the single end-of-text token so the
        model always has at least one conditioning token.
        Returns whatever ``_loglikelihood_tokens`` returns, in request order.
        """
        encoded_pairs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                ctx_tokens = [50256]
            else:
                ctx_tokens = self.tokenizer.encode(context)

            cont_tokens = self.tokenizer.encode(continuation)
            encoded_pairs.append((ctx_tokens, cont_tokens))

        return self._loglikelihood_tokens(encoded_pairs)

    def _loglikelihood_tokens(self, requests):
        """Score pre-tokenized (context_enc, continuation_enc) pairs.

        :param requests: iterable of (context_tokens, continuation_tokens)
            lists of token ids.
        :return: list of (sum_logprob, is_greedy) tuples in the original
            request order, where is_greedy is True iff the continuation is
            exactly what greedy argmax decoding would have produced.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

            def _collate(x):
                # Sort key: total length (and the tokens themselves) so that
                # similarly-sized requests end up adjacent after reordering.
                toks = x[0] + x[1]
                return (len(toks), tuple(toks))

            reord = utils.Reorderer(requests, _collate)
            for context_enc, continuation_enc in tqdm(reord.get_reordered()):
                # when too long to fit in context, truncate from the left
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                # ctxlen shrinks by however many context tokens were truncated.
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # Slice logits to 50257 entries: GPT-2's real vocab size
                # (some checkpoints pad the output matrix beyond it), then
                # take the positions that predict the continuation tokens.
                logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # NOTE(review): removed unused `last_token_slice` — it was
                # materialized via .tolist() every iteration but never read.

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                # TODO: make sure that decode reverses correctly
                self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        """Greedily generate for each (context, until) request and return the
        decoded completions, truncated at the first stop sequence.

        ``until`` may be a single string or a list of strings; the first one
        must encode to exactly one token (it is used as the EOS id).
        """
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        completions = []

        def _collate(req):
            # Order by tokenized-context length so similar lengths batch well.
            toks = self.tokenizer.encode(req[0])
            return (len(toks), req[0])

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            # Keep only the most recent tokens so the prompt plus the
            # generation budget still fits in the model's context window.
            ctx_tokens = self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]
            context_enc = torch.tensor([ctx_tokens]).to(self.device)

            # Unpack asserts the primary stop string is a single token.
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            new_tokens = cont[0].tolist()[context_enc.shape[1]:]
            s = self.tokenizer.decode(new_tokens)

            # Trim at the earliest occurrence of every stop sequence.
            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            completions.append(s)

        return reord.get_original(completions)