gpt2.py 2.03 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
Jason Phang's avatar
Jason Phang committed
3
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
4
5
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
6
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
7
8
9


class GPT2LM(LM):
    """Evaluation adapter wrapping a HuggingFace GPT-2 causal language model.

    Implements the ``LM`` interface used by the evaluation harness:
    ``loglikelihood`` scores (context, continuation) pairs; ``greedy_until``
    is not yet implemented.
    """

    # Token id of GPT-2's "<|endoftext|>" token.  Used as the stand-in
    # context when a request supplies an empty context string.
    EOT_TOKEN_ID = 50256

    def __init__(self, device="cpu", pretrained='gpt2'):
        """Load the model and tokenizer.

        :param device: torch device string, e.g. "cpu" or "cuda".
        :param pretrained: HuggingFace model name or path (default "gpt2").
        """
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        self.tokenizer.pad_token = "<|endoftext|>"
        # Context window size taken from the model config instead of a
        # hard-coded 1024, so non-standard checkpoints also work.
        # (n_ctx == 1024 for the stock "gpt2" family, so this is
        # backward compatible.)
        self.max_length = self.gpt2.config.n_ctx

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Alternate constructor from a "key=value,key=value" arg string."""
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        """Score each (context, continuation) pair.

        :param requests: iterable of (context, continuation) string pairs.
        :returns: list of (log-likelihood, is_greedy) tuples, where
            is_greedy is True iff the continuation is exactly the model's
            greedy decoding of those positions.
        """
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            for context, continuation in tqdm(requests):
                if context == "":
                    # Empty context: condition on end-of-text instead.
                    context_enc = [self.EOT_TOKEN_ID]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
                # When too long to fit in the context window, truncate
                # from the left (keep the most recent tokens).
                inp = torch.tensor(
                    [(context_enc + continuation_enc)[-self.max_length:]],
                    dtype=torch.long,
                ).to(self.device)
                # Number of surviving context tokens after truncation.
                # NOTE(review): if the continuation alone fills the window,
                # ctxlen can reach 0 and the logits slice below is empty —
                # assumed not to happen for harness tasks; verify upstream.
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # Log-probs at the positions that predict the continuation
                # tokens (shifted left by one).
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # Gather the log-prob assigned to each actual continuation token.
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res

    def greedy_until(self, requests):
        """Greedy generation up to stop sequences — not yet implemented.

        Previously a bare ``pass`` that silently returned None; fail
        loudly instead so callers do not misinterpret the result.
        """
        # TODO: implement
        raise NotImplementedError("greedy_until is not yet implemented for GPT2LM")