gpt2.py 5.57 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
3
from typing import Optional, Union
4
from lm_eval.base import BaseLM
Jason Phang's avatar
gpt3  
Jason Phang committed
5

6

7
8
9
10
11
12
13
14
15
16
17
18
def _get_dtype(
    dtype: Union[str, torch.dtype]
) -> torch.dtype:
    """Convert a dtype name to the matching `torch.dtype` when possible.

    Strings other than ``"auto"`` are looked up as attributes of the `torch`
    module (``"float16"`` -> ``torch.float16``). The value ``"auto"`` and any
    non-string value are returned unchanged. Does not use an instantiated
    HF AutoConfig.

    Args:
        dtype: A dtype name, the string ``"auto"``, or an object that is
            passed through as-is.

    Returns:
        The resolved `torch.dtype`, or `dtype` itself when no conversion
        applies.

    Raises:
        AttributeError: If `torch` has no attribute with the given name.
        ValueError: If the name resolves to a `torch` attribute that is not
            a `torch.dtype` (for example ``"nn"``).
    """
    if not isinstance(dtype, str) or dtype == "auto":
        # Nothing to convert: hand the value back untouched.
        return dtype

    # Convert `str` args torch dtype: `float16` -> `torch.float16`
    resolved = getattr(torch, dtype)
    if not isinstance(resolved, torch.dtype):
        # `torch` exposes many non-dtype attributes (modules, functions);
        # reject them here instead of letting them leak into model loading.
        raise ValueError(
            f"'{dtype}' does not name a torch dtype (got {type(resolved).__name__})"
        )
    return resolved


19
class HFLM(BaseLM):
    """`BaseLM` implementation backed by a `transformers` causal language model.

    The model is stored on `self.gpt2` and the tokenizer on `self.tokenizer`.
    """

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
    ):
        """Load (or adopt) a model and its tokenizer.

        Args:
            device: Device string. Only consulted when `pretrained` is a
                string; must be "cuda", "cpu" or "cuda:<i>" for a visible
                GPU, otherwise CUDA is used when available, else CPU.
            pretrained: Either an instantiated `transformers.PreTrainedModel`
                (used as-is, on the device it already lives on) or a string
                passed to `AutoModelForCausalLM.from_pretrained`.
            revision: Revision passed to `from_pretrained`.
            low_cpu_mem_usage: Forwarded to the model's `from_pretrained`.
            subfolder: When not None, appended to `revision` as
                "<revision>/<subfolder>".
            tokenizer: With an instantiated model: an optional tokenizer
                instance. With a string `pretrained`: an optional name/path
                for `AutoTokenizer.from_pretrained` (defaults to
                `pretrained`).
            batch_size: An int (or int-like string), or "auto", which is
                stored unchanged.
            load_in_8bit: Forwarded to the model's `from_pretrained`.
            trust_remote_code: Forwarded to both `from_pretrained` calls.
            dtype: Converted with `_get_dtype` and forwarded as
                `torch_dtype`.
        """
        super().__init__()

        # Initialize model
        if isinstance(pretrained, transformers.PreTrainedModel):
            self.gpt2 = pretrained
            self._device = self.gpt2.device

            if tokenizer:
                assert isinstance(
                    tokenizer,
                    (
                        transformers.PreTrainedTokenizer,
                        transformers.PreTrainedTokenizerFast,
                    ),
                )
                self.tokenizer = tokenizer
            else:
                # Get tokenizer
                model_name = self.gpt2.name_or_path
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_name,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                )
        else:
            # Initialize device
            assert isinstance(device, str)
            device_list = {"cuda", "cpu"} | {
                f"cuda:{i}" for i in range(torch.cuda.device_count())
            }
            if device and device in device_list:
                self._device = torch.device(device)
                print(f"Using device '{device}'")
            else:
                print("Device not specified")
                print(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            assert isinstance(pretrained, str)

            # The subfolder is addressed through the revision string.
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            model = transformers.AutoModelForCausalLM.from_pretrained(
                pretrained,
                load_in_8bit=load_in_8bit,
                low_cpu_mem_usage=low_cpu_mem_usage,
                revision=revision,
                torch_dtype=_get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            )
            try:
                model = model.to(self.device)
            except ValueError:
                # transformers raises ValueError from `.to()` for
                # bitsandbytes-quantized models, which the loader has already
                # placed. Tolerate that only for 8-bit loads and follow the
                # model's own device so inputs end up next to the weights.
                if not load_in_8bit:
                    raise
                self._device = model.device
                print(f"8-bit model left on device '{self._device}'")
            self.gpt2 = model

            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer if tokenizer else pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        self.gpt2.eval()

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            # Sanity check: a GPT-2 tokenizer must produce these exact ids.
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # Validate batch_size
        assert isinstance(batch_size, (int, str))

        # setup for automatic batch size detection
        if batch_size == "auto":
            self.batch_size_per_gpu = batch_size
        else:
            self.batch_size_per_gpu = int(batch_size)

    @property
    def eot_token_id(self):
        """Return the tokenizer's `eos_token_id`."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return `config.n_ctx`, falling back to `config.max_position_embeddings`."""
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        """Return the fixed generation budget of 256 tokens."""
        return 256

    @property
    def batch_size(self):
        """Return the stored per-GPU batch size (an int, or the string "auto")."""
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        """Return the device selected in `__init__`."""
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        """Encode `string` with the tokenizer, without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode `tokens` back to text with the tokenizer."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Run non-sampling generation from `context` up to `max_length`.

        When `eos_token_id` is not None it is passed to `generate` as both
        the EOS token and the pad token.
        """
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
            # setting eos_token_id as pad token
            generation_kwargs["pad_token_id"] = eos_token_id
        return self.gpt2.generate(context, **generation_kwargs)
168
169


170
171
# Alias kept for backwards compatibility: `GPT2LM` refers to the same class
# object as `HFLM`.
GPT2LM = HFLM