import torch
import transformers
from typing import Optional, Union
from lm_eval.base import BaseLM


def _get_dtype(
    dtype: Union[str, torch.dtype]
) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype


class HFLM(BaseLM):
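    """Hugging Face causal LM wrapper exposing the BaseLM interface.

    `pretrained` may be a model name/path passed to `AutoModelForCausalLM.from_pretrained`
    or an already-instantiated `transformers.PreTrainedModel`.
    """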

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        max_batch_size=512,
        max_length=None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
    ):
        super().__init__()

        # Initialize model
        if isinstance(pretrained, transformers.PreTrainedModel):
            self.model = pretrained
            self._device = self.model.device

            if tokenizer:
                assert isinstance(
                    tokenizer,
                    (
                        transformers.PreTrainedTokenizer,
                        transformers.PreTrainedTokenizerFast,
                    ),
                )
                self.tokenizer = tokenizer
            else:
                # Get tokenizer
                model_name = self.model.name_or_path
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_name,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                )

        elif isinstance(pretrained, str):

            # Initialize device
            assert isinstance(device, str)
            device_list = set(
                ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            )
            if device and device in device_list:
                self._device = torch.device(device)
                print(f"Using device '{device}'")
            else:
                print("Device not specified")
                print(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            # Initialize new model and tokenizer instances
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                pretrained,
                load_in_8bit=load_in_8bit,
                low_cpu_mem_usage=low_cpu_mem_usage,
                revision=revision,
                torch_dtype=_get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            ).to(self.device)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer if tokenizer else pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        else:
            raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')

        self.model.eval()

        self.vocab_size = self.tokenizer.vocab_size

        # Validate batch_size
        assert isinstance(batch_size, (int, str))

        # setup for automatic batch size detection
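        # batch_size may be an int, "auto", or "auto:N"; the optional N is stored
        # as self.batch_schedule (defaults to 1)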
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)
        self.max_batch_size = max_batch_size

        self._max_length = max_length

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
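        # Resolution order: explicit `max_length` argument, then the model config
        # (n_positions / max_position_embeddings / n_ctx), then the tokenizer's
        # model_max_length, then _DEFAULT_MAX_LENGTH.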
        if self._max_length: # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
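            # HF tokenizers report int(1e30) as model_max_length when it is unset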
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs['eos_token_id'] = eos_token_id
            generation_kwargs['pad_token_id'] = eos_token_id  # use EOS as pad token; GPT-style tokenizers often define no pad token
        return self.model.generate(context, **generation_kwargs)


# for backwards compatibility
GPT2LM = HFLM
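

# Minimal usage sketch, assuming the "gpt2" checkpoint is available from the
# Hugging Face Hub or a local cache; it exercises the helpers defined above.
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
    token_ids = lm.tok_encode("Hello world")  # token ids, no special tokens added
    logits = lm._model_call(torch.tensor([token_ids], device=lm.device))  # [batch, seq, vocab]
    print(logits.shape, lm.max_length, lm.eot_token_id)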