import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            self.max_length = self.gpt2.config.max_position_embeddings

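        # sanity check: the tokenizer must behave like the standard GPT-2 BPE,
        # splitting 'hello\n\nhello' into exactly these four token ids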
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

            def _collate(x):
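                # sort key: total token count of context + continuation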
                toks = self.tokenizer.encode(x[0] + x[1])
                return (len(toks), x)
            
            reord = utils.Reorderer(requests, _collate)
            for context, continuation in tqdm(reord.get_reordered()):
                # when too long to fit in context, truncate from the left
                combined_toks = self.tokenizer.encode(context + continuation)

                if context == "":
                    # end of text as context
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
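                # left-truncate to the model's context window; ctxlen is the number
                # of context tokens that survive the truncation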
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
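                # score only the first 50257 logits (the GPT-2 vocab) and keep the
                # positions whose next-token prediction is a continuation token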
                logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

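                # would greedy decoding reproduce the continuation exactly?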
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                last_token_slice = logits[:, -1, :].squeeze(0).tolist()

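                # pick out the log-probability assigned to each actual continuation token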
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]

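                # (sum of continuation log-probs, whether the continuation is the greedy output)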
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

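        # sort requests by context token length; the original order is restored at the end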
        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

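            # leave room for generation: keep only the last (max_length - MAX_GEN_TOKS) context tokens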
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

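            # generate greedily, stopping at the first stop sequence's token id
            # (note: until[0] must encode to a single token here) or after MAX_GEN_TOKS new tokens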
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)