gpt2.py 4.7 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
3
from typing import Optional, Union
4
from lm_eval.base import BaseLM
Jason Phang's avatar
gpt3  
Jason Phang committed
5

6

7
8
9
10
11
12
13
14
15
16
17
18
def _get_dtype(
    dtype: Union[str, torch.dtype]
) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype


19
class HFLM(BaseLM):
    """Evaluation wrapper around a Hugging Face causal language model.

    The model is loaded with ``transformers.AutoModelForCausalLM`` and stored
    as ``self.gpt2`` whatever its architecture. The tokenizer is loaded with
    ``transformers.AutoTokenizer``.
    """

    _DEFAULT_MAX_LENGTH = 2048
    # Placeholder `transformers` stores in `tokenizer.model_max_length`
    # (int(1e30)) when the tokenizer config does not define a real limit.
    _HF_UNSET_MODEL_MAX_LENGTH = 1000000000000000019884624838656

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        max_length=None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
    ):
        """Load the model and tokenizer.

        Args:
            device: "cuda", "cpu", or "cuda:<i>" for a visible GPU index. Any
                other value falls back to CUDA if available, else CPU.
            pretrained: Model name or path passed to ``from_pretrained``.
            revision: Model revision. ``subfolder`` is appended to it.
            low_cpu_mem_usage: Forwarded to the model's ``from_pretrained``.
            subfolder: Optional subfolder, appended as "<revision>/<subfolder>".
            tokenizer: Tokenizer name or path. Defaults to ``pretrained``.
            batch_size: An int, a numeric string, or the string "auto".
            max_length: Optional override for the context length. It is
                coerced with ``int()``, like ``batch_size``.
            load_in_8bit: Forwarded to ``from_pretrained``. When set, the
                model is not moved with ``.to()`` after loading.
            trust_remote_code: Forwarded to both ``from_pretrained`` calls.
            dtype: A torch dtype, a dtype name such as "float16", or "auto".
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        device_list = set(
            ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        )
        if device and device in device_list:
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=load_in_8bit,
            low_cpu_mem_usage=low_cpu_mem_usage,
            revision=revision,
            torch_dtype=_get_dtype(dtype),
            trust_remote_code=trust_remote_code,
        )
        if not load_in_8bit:
            # 8-bit (bitsandbytes) models are placed on their device(s) while
            # loading, and `transformers` refuses `.to()` on them. Only move
            # models that were loaded normally.
            self.gpt2 = self.gpt2.to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # setup for automatic batch size detection
        if batch_size == "auto":
            self.batch_size_per_gpu = batch_size
        else:
            self.batch_size_per_gpu = int(batch_size)

        # Coerce the same way as batch_size, which also accepts strings, so a
        # textual value such as "2048" does not leak out of `max_length`.
        self._max_length = int(max_length) if max_length is not None else None

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context length to use for evaluation.

        Resolution order:
        1. The explicit ``max_length`` constructor argument, if truthy.
        2. The first of ``n_positions``, ``max_position_embeddings`` and
           ``n_ctx`` found on the model config.
        3. ``tokenizer.model_max_length``, unless it is the HF "unset"
           placeholder.
        4. ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.gpt2.config, attr):
                return getattr(self.gpt2.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == self._HF_UNSET_MODEL_MAX_LENGTH:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        # Fixed generation budget in tokens.
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        """Tokenize ``string`` into ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode a sequence of token ids back into text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Greedily generate from ``context`` (``do_sample=False``).

        ``max_length`` is forwarded unchanged as ``generate(max_length=...)``.
        When ``eos_token_id`` is given, it is used both as the stop token and
        as the pad token.
        """
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs['eos_token_id'] = eos_token_id
            generation_kwargs['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
        return self.gpt2.generate(context, **generation_kwargs)
144
145


146
147
# Backwards-compatible alias: code that still refers to `GPT2LM` gets the
# very same class object as `HFLM`.
GPT2LM = HFLM