"include/ck/utility/amd_inline_asm.hpp" did not exist on "b6e1c52a80086b1cc711ad0270f34e6fd1181709"
gpt2.py 4.29 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        self.max_length = self.gpt2.config.n_ctx

        # sanity check that this is the expected GPT-2 tokenizer (its hardcoded ids, e.g. 50256, are relied on below)
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

            # reorder requests by tokenized length so similarly sized inputs are
            # processed together; results are mapped back to the original order below
            def _collate(x):
                toks = self.tokenizer.encode(x[0] + x[1])[:-1]
                return (len(toks), self.tokenizer.decode(toks))

            reord = utils.Reorderer(requests, _collate)
            for context, continuation in tqdm(reord.get_reordered()):
                # when too long to fit in context, truncate from the left

                if context == "":
                    # end of text as context
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)
                
                continuation_enc = self.tokenizer.encode(continuation)
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                # number of context tokens that survive the left-truncation; everything from ctxlen onward is continuation
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
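                # Worked example of the truncation arithmetic (illustrative numbers, not from the source):
                # with max_length = 1024, len(context_enc) = 1000 and len(continuation_enc) = 50, the
                # combined 1050 tokens are cut to the last 1024, and ctxlen = 1000 - max(0, 1050 - 1024) = 974,
                # so the first 26 context tokens are dropped and all 50 continuation tokens are still scored.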

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # logits at position i predict token i + 1, so slice [ctxlen - 1:-1] to align predictions with cont_toks
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                # check whether greedy decoding would reproduce the continuation exactly
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # keep the full distribution over the final position for the last-token optimization applied after the loop
                last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                # pick out the log-probability assigned to each actual continuation token
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                # sum the log-probabilities of all but the last continuation token; the last token's
                # log-probability is added back from last_token_slice once the loop finishes
                res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))

        # optimization: if two requests have everything the same except the last token, the
        # last-token distribution can be reused to save compute
        lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
        # each result is (log-probability of the continuation, whether greedy decoding matches it exactly)
        return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` support that correctly handles stop
        # sequences which are multiple tokens long or which span token boundaries
        res = []

        # reorder requests by tokenized context length; results are restored to the original order at the end
        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            # `until` may be given as a single stop string or a list of stop strings
            if isinstance(until, str):
                until = [until]

            # keep only the last (max_length - MAX_GEN_TOKS) context tokens so there is room to generate MAX_GEN_TOKS tokens
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

            # the first stop sequence must encode to a single token; it doubles as the eos token for generation
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            # strip the prompt tokens and decode only the generated part
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # cut the generation at the first occurrence of any stop sequence
            for term in until:
                s = s.split(term)[0]

            res.append(s)
        
        return reord.get_original(res)
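

# Minimal usage sketch (not part of the original module): it assumes the argument-string
# format accepted by utils.simple_parse_args_string ("key=value" pairs separated by commas)
# and the request shapes used above, i.e. (context, continuation) pairs for loglikelihood
# and (context, until) pairs for greedy_until.
if __name__ == "__main__":
    lm = GPT2LM.create_from_arg_string("pretrained=gpt2,device=cpu")

    # score a single continuation given a context; returns (log-probability, is_greedy)
    (logprob, is_greedy), = lm.loglikelihood([("The capital of France is", " Paris")])
    print(logprob, is_greedy)

    # generate greedily until the stop string appears (or MAX_GEN_TOKS tokens are produced)
    completion, = lm.greedy_until([("Q: What is the capital of France?\nA:", ["\n"])])
    print(completion)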