"lm_eval/models/hf_causal.py" did not exist on "2da74953dee265e347298a788b4fd09f2c0344c2"
gpt2.py 4.15 KB
Newer Older
import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # the pretrained tokenizer for GPT-Neo is broken for now, so hardcode the GPT-2 tokenizer
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig doesn't have n_ctx, apparently
            self.max_length = self.gpt2.config.max_position_embeddings

        # sanity check: the tokenizer should produce GPT-2's expected token ids
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
        """Return a (summed continuation log-prob, is-greedy) tuple for each (context, continuation) request."""
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

            def _collate(x):
                toks = self.tokenizer.encode(x[0] + x[1])
                return (len(toks), x)
            
            reord = utils.Reorderer(requests, _collate)
            for context, continuation in tqdm(reord.get_reordered()):
                # when too long to fit in context, truncate from the left
                combined_toks = self.tokenizer.encode(context + continuation)  # NOTE: not used further below

                if context == "":
                    # end of text as context
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
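                # Worked example of the truncation arithmetic above (illustrative numbers only,
                # assuming max_length == 1024): a 1030-token context plus a 10-token continuation
                # is 1040 tokens, so inp keeps the last 1024 of them and
                # ctxlen = 1030 - max(0, 1040 - 1024) = 1014, leaving the continuation in the
                # final 10 positions of the truncated window.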

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                last_token_slice = logits[:, -1, :].squeeze(0).tolist()  # NOTE: not used further below

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        """For each (context, until) request, greedily generate from context until an `until` string or MAX_GEN_TOKS tokens."""
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

            # keep only the last (max_length - MAX_GEN_TOKS) context tokens so there is room to generate
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

            # assumes the primary stop string encodes to exactly one token
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            res.append(s)
        
        return reord.get_original(res)
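

# A minimal usage sketch, assuming the lm_eval base classes are importable and a GPT-2
# checkpoint can be downloaded; the prompts and stop strings below are made up for
# illustration only.
if __name__ == "__main__":
    lm = GPT2LM(device="cpu")

    # loglikelihood consumes (context, continuation) pairs and returns
    # (summed continuation log-prob, is-greedy) tuples in the original request order.
    print(lm.loglikelihood([("The capital of France is", " Paris")]))

    # greedy_until consumes (context, until) pairs and returns the generated text,
    # cut off at the first occurrence of any stop string.
    print(lm.greedy_until([("Q: What is 2 + 2?\nA:", "\n")]))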