import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device="cpu", pretrained='gpt2'):
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        self.tokenizer.pad_token = "<|endoftext|>"  # GPT-2 has no pad token by default; reuse EOS

        # sanity check: consecutive newlines must encode as two separate tokens
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))
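
    # Usage sketch (illustrative; arg_string is a comma-separated key=value
    # list as handled by utils.simple_parse_args_string):
    #
    #   lm = GPT2LM.create_from_arg_string("device=cuda,pretrained=gpt2-medium")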

    def loglikelihood(self, requests):
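        """Score each (context, continuation) pair.

        Returns one (continuation log-likelihood, is-greedy) tuple per
        request, where is-greedy indicates whether the continuation is
        exactly the model's greedy completion of the context.
        """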
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization
            for context, continuation in tqdm(requests):
                # when too long to fit in context, truncate from the left

                if context == "":
                    # empty context: condition on <|endoftext|> (id 50256) instead
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)
                
                continuation_enc = self.tokenizer.encode(continuation)
                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
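                # Worked example (illustrative numbers): 1020 context tokens +
                # 10 continuation tokens overflow the 1024-token window by 6,
                # so the 6 leftmost context tokens are dropped and
                # ctxlen = 1020 - 6 = 1014.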

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]; offset by one since position i predicts token i + 1
                
                # is the continuation exactly what greedy decoding would produce?
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # gather the log-prob assigned to each actual continuation token
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res
    
    def greedy_until(self, requests):
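        """Greedily generate from each context until a stop sequence is hit.

        Each request is a (context, until) pair, where `until` is a stop
        string or a list of stop strings; the returned text has the stop
        sequence and everything after it stripped off.
        """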
        # TODO: implement fully general `until` support that correctly
        # handles stop sequences which are, or which span, multiple tokens
        res = []

        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]

            # keep at most 1024 - MAX_GEN_TOKS context tokens so generation fits the window
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)

            # the primary stop sequence must encode to a single token (see TODO above)
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # truncate at the first occurrence of any stop sequence
            for term in until:
                s = s.split(term)[0]
            
            res.append(s)
        
        return res
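

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the harness API):
    # score one context/continuation pair, then generate greedily to a newline.
    lm = GPT2LM(device="cpu", pretrained="gpt2")

    ll, is_greedy = lm.loglikelihood([("The capital of France is", " Paris")])[0]
    print(f"loglikelihood={ll:.3f}, greedy={is_greedy}")

    completion, = lm.greedy_until([("Q: What is 2 + 2?\nA:", ["\n"])])
    print(completion)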