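# Adapter exposing GPT-2-style causal language models to the lm_eval harness:
# loglikelihood scoring of (context, continuation) pairs plus greedy generation.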
import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
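    # maximum number of new tokens to generate in greedy_until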
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
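        # n_ctx is the model's context window (1024 for GPT-2); the field name is
        # GPT-2-specific, so other configs may expose max_position_embeddings instead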
        self.max_length = self.gpt2.config.n_ctx

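        # sanity check: the tokenizer must match GPT-2's BPE exactly
        # (in particular, the two newlines must stay separate tokens)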
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
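        # arg_string is a comma-separated key=value list,
        # e.g. "pretrained=gpt2-medium,device=cuda:0"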
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
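        # for each (context, continuation) pair, returns a (sum of the continuation
        # tokens' log-probs, whether the continuation is the greedy completion) tuple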
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization
            for context, continuation in tqdm(requests):
                # when too long to fit in context, truncate from the left

                if context == "":
                    # empty context: condition on end-of-text instead
                    # (50256 is <|endoftext|> in the GPT-2 vocab)
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)
                
                continuation_enc = self.tokenizer.encode(continuation)
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
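                # number of context tokens that survive left-truncation; the
                # continuation occupies positions ctxlen: of the input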
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
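                # logits at position i predict token i + 1, so the slice
                # [ctxlen - 1:-1] aligns each prediction with its continuation token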
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
                
                # check whether greedy argmax decoding would reproduce the continuation exactly
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # pick out the log-prob assigned to each actual continuation token
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]

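            # keep only the last (max_length - MAX_GEN_TOKS) prompt tokens so
            # generation has room for up to MAX_GEN_TOKS new tokens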
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

            # `until[0]` must encode to a single token; it doubles as the eos id
            # below so that generation can stop early
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            # decode only the newly generated tokens, dropping the prompt
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # truncate at the first occurrence of any stop sequence
            for term in until:
                s = s.split(term)[0]
            
            res.append(s)
        
        return res
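
# Illustrative usage (a minimal sketch; in practice the evaluation harness
# constructs the model and issues these requests itself):
#
#   lm = GPT2LM.create_from_arg_string("pretrained=gpt2,device=cuda")
#   lm.loglikelihood([("The capital of France is", " Paris")])
#   lm.greedy_until([("Q: What is 2 + 2?\nA:", "\n")])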