"lm_eval/models/openai.py" did not exist on "9454c839699708fed1efd152f82059de7b28ae7d"
gpt2.py 5.63 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        parallelize=False
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        if device:
            self._device = torch.device(device)
        else:
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        )
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            subfolder=subfolder,
        )

        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        if parallelize:
            self.gpt2.parallelize()
            self._device = torch.device("cuda:0")
        else:
            self.gpt2.to(self._device)

    @property
    def eot_token(self):
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            # keep only the first 50257 logits (the GPT-2 vocabulary); some checkpoints pad their output layer beyond this
            return self.gpt2(inps)[0][:, :, :50257]

    def _get_stopping_criteria(self, stopping_criteria_ids):
        class MultitokenEOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
                self.eos_seq = tokenizer.decode(eos_seq_id)
                self.eos_seq_id = eos_seq_id
                # look back one extra token to catch stop sequences that straddle token boundaries
                self.eos_seq_len = len(eos_seq_id) + 1
                self.tokenizer = tokenizer

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                last_token_ids = input_ids[0, -self.eos_seq_len:]
                last_tokens = self.tokenizer.decode(last_token_ids)
                return self.eos_seq in last_tokens
        
        class EOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_token_id: torch.LongTensor):
                self.eos_token_id = eos_token_id

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                return input_ids[0, -1] == self.eos_token_id
         
        return transformers.StoppingCriteriaList([
            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
            EOSCriteria(self.tokenizer.eos_token_id),
        ])
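
    # Illustrative note (not in the original source): for a task whose stop
    # string is "\n\n", the caller would typically pass something like
    #     stopping_criteria_ids = tok_encode("\n\n")
    # and MultitokenEOSCriteria halts generation once the decoded tail of the
    # output contains "\n\n".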

    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
        # HF's `generate` counts the prompt toward max_length, so extend it by the context length
        max_length = max_length + context.size(1)
        # zero-shot stops on the model's EOS token; few-shot stops on the task's stop sequence via the criteria above
        if num_fewshot == 0:
            generations = self.gpt2.generate(
                context, 
                max_length=max_length, 
                eos_token_id=self.eot_token_id,
                do_sample=False,
            )
        else:
            generations = self.gpt2.generate(
                context, 
                max_length=max_length, 
                stopping_criteria=stopping_criteria,
                do_sample=False,
            )

        # Remove the context from the generations
        return generations[0, context.shape[1] :]

# for backwards compatibility
GPT2LM = HFLM
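

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumes network access to download the "gpt2" checkpoint and enough memory
# to run it on CPU.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)

    # encode a prompt and score it with a single forward pass
    ids = lm.tok_encode("hello\n\nhello")
    logits = lm._model_call(torch.tensor([ids], device=lm.device))
    print(logits.shape)  # expected: torch.Size([1, 4, 50257])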