gpt2.py 2.95 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
Jason Phang's avatar
Jason Phang committed
3
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
4
5
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
6
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
7
8
9


class GPT2LM(LM):
    """lm_eval wrapper around HuggingFace's GPT-2 causal language model.

    Loads a pretrained ``GPT2LMHeadModel`` plus its fast tokenizer and
    implements the ``loglikelihood`` and ``greedy_until`` request
    interfaces of ``LM``.
    """

    # Maximum number of new tokens to generate per request in `greedy_until`.
    MAX_GEN_TOKS = 256
    # GPT-2's context window size; longer sequences are truncated from the left.
    MAX_LENGTH = 1024
    # GPT-2's <|endoftext|> token id, used as the stand-in for an empty context.
    EOT_TOKEN_ID = 50256

    def __init__(self, device="cpu", pretrained='gpt2'):
        """
        Args:
            device: torch device string, e.g. "cpu" or "cuda".
            pretrained: HuggingFace model name or local path (default "gpt2").
        """
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()  # inference only: disable dropout etc.
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        # GPT-2 ships without a pad token; reuse end-of-text for padding.
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Construct a GPT2LM from a comma-separated "key=value" arg string."""
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        """Score continuations given contexts.

        Args:
            requests: iterable of (context, continuation) string pairs.

        Returns:
            list of (log_prob, is_greedy) tuples: log_prob is the sum of the
            continuation tokens' log-probabilities; is_greedy is True when the
            continuation exactly matches the model's greedy decoding.
        """
        res = []

        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization
            for context, continuation in tqdm(requests):
                if context == "":
                    # Empty context: condition on end-of-text alone.
                    context_enc = [self.EOT_TOKEN_ID]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)

                # When too long to fit in the context window, truncate from
                # the left; shrink ctxlen by however many context tokens were
                # dropped so the continuation span stays aligned.
                inp = torch.tensor(
                    [(context_enc + continuation_enc)[-self.MAX_LENGTH:]],
                    dtype=torch.long,
                ).to(self.device)
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH
                )

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # Logits at position i predict token i+1, hence the one-step shift.
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # Pick out the log-probs of the actual continuation tokens.
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res

    def greedy_until(self, requests):
        """Greedily generate from each context until a stop string appears.

        Args:
            requests: iterable of (context, until) pairs, where until is a
                stop string or a list of stop strings.

        Returns:
            list of generated strings, each cut at the first stop string.
        """
        # TODO: implement fully general `until` that handles untils that are
        # multiple tokens or that span multiple tokens correctly
        res = []

        # Generation needs no gradients; mirrors `loglikelihood` and avoids
        # accumulating autograd state across generate() calls.
        with torch.no_grad():
            for context, until in tqdm(requests):
                if isinstance(until, str):
                    until = [until]

                # Leave room for MAX_GEN_TOKS new tokens inside the context
                # window by truncating the prompt from the left.
                context_enc = torch.tensor(
                    [self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.MAX_LENGTH:]]
                ).to(self.device)

                # NOTE(review): assumes the primary stop string encodes to a
                # single token (the unpacking raises otherwise) — see TODO above.
                primary_until, = self.tokenizer.encode(until[0])

                cont = self.gpt2.generate(
                    context_enc,
                    max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                    eos_token_id=primary_until,
                    do_sample=False,
                )

                # Decode only the newly generated tokens, not the prompt.
                s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

                # Cut at the first occurrence of any stop string.
                for term in until:
                    s = s.split(term)[0]

                res.append(s)

        return res