import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
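    # Adapter that exposes a HuggingFace GPT-2 checkpoint through the lm_eval
    # LM interface: per-request loglikelihood scoring and greedy generation.
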
    MAX_GEN_TOKS = 256

    def __init__(self, device="cpu", pretrained="gpt2"):
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        # GPT-2 ships without a pad token; reuse the end-of-text token so
        # padding does not require adding a new embedding
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
    def create_from_arg_string(cls, arg_string):
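        # arg_string is a comma-separated key=value list parsed by
        # utils.simple_parse_args_string; illustrative example (values assumed,
        # not from the original file): "device=cuda:0,pretrained=gpt2-medium"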
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
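        # requests: iterable of (context, continuation) string pairs; returns a
        # list of (summed log-probability of the continuation, whether the
        # continuation is exactly the greedy completion) tuples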
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization
            for context, continuation in tqdm(requests):
                # when too long to fit in the 1024-token window, truncate from the left
                if context == "":
                    # empty context: condition on the end-of-text token (id 50256)
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
                # number of context tokens that survived the left-truncation;
                # assumes the continuation alone is shorter than 1024 tokens,
                # so at least one context token remains
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
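                # Worked example (illustrative numbers, not from the original
                # source): with 1000 context tokens and 50 continuation tokens,
                # the overflow is 1000 + 50 - 1024 = 26, so 26 tokens fall off
                # the left and ctxlen = 1000 - 26 = 974; the continuation still
                # occupies the final 50 positions of inp.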

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # the logit at position i predicts token i + 1, so shift the
                # window back by one to align predictions with cont_toks
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                # is the continuation exactly what greedy decoding would produce?
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # pick out the log-probability assigned to each actual continuation token
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res
    
    def greedy_until(self, requests):
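        # requests: iterable of (context, until) pairs, where until is a stop
        # string or a list of stop strings; returns the generated text up to
        # (and excluding) the first stop sequence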
        # TODO: implement fully general `until` handling for stop sequences
        # that are multiple tokens long or that span token boundaries
        res = []

        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]

            # keep only the last 1024 - MAX_GEN_TOKS = 768 prompt tokens so the
            # generation budget fits within GPT-2's 1024-token window
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)

            # tuple unpacking doubles as an assertion that the primary stop
            # sequence encodes to exactly one token, usable as an eos_token_id
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            # drop the prompt; keep only the newly generated tokens
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # truncate at the first occurrence of any stop sequence
            for term in until:
                s = s.split(term)[0]

            res.append(s)

        return res
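

# --- Minimal usage sketch (not part of the original file) ---
# Assumes lm_eval is importable and the default 124M-parameter "gpt2"
# checkpoint fits in memory; the request shapes mirror the
# (context, continuation) and (context, until) tuples consumed above.
if __name__ == "__main__":
    lm = GPT2LM.create_from_arg_string("device=cpu,pretrained=gpt2")

    # each result is a (summed log-probability, is-greedy) pair
    print(lm.loglikelihood([("The capital of France is", " Paris")]))

    # generate greedily until a newline ("\n" encodes to a single token) is produced
    print(lm.greedy_until([("Q: What is 2 + 2?\nA:", "\n")]))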