gpt2.py
import torch
import transformers
from typing import Optional
from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        device_list = set(
            ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        )
        if device and device in device_list:
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
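        # e.g. revision="main" with subfolder="model" becomes "main/model"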
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=load_in_8bit,
            low_cpu_mem_usage=low_cpu_mem_usage,
            revision=revision,
            trust_remote_code=trust_remote_code,
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # setup for automatic batch size detection
        if batch_size == "auto":
            self.batch_size_per_gpu = batch_size
        else:
            self.batch_size_per_gpu = int(batch_size)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
            generation_kwargs["pad_token_id"] = eos_token_id  # use eos_token_id as pad token
        return self.gpt2.generate(context, **generation_kwargs)


# for backwards compatibility
GPT2LM = HFLM
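

# Illustrative usage sketch (not part of the original module): exercises the
# tokenizer helpers and a few of the BaseLM properties defined above. The model
# name, device, and prompt below are arbitrary placeholders.
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
    ids = lm.tok_encode("hello\n\nhello")  # token ids, no special tokens added
    print(ids, lm.tok_decode(ids))  # round-trip through the tokenizer
    print(lm.eot_token_id, lm.max_length)  # EOT id and model context length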