import transformers
import torch
from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        if device:
            if device not in ["cuda", "cpu"]:
                device = int(device)
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        # sanity check: the GPT-2 BPE tokenizer should encode "hello\n\nhello" to exactly these ids
        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            # keep only the first 50257 logits (the GPT-2 vocabulary size);
            # some checkpoints pad the output layer beyond this
            return self.gpt2(inps)[0][:, :, :50257]

    def _model_generate(self, context, max_length, eos_token_id, temperature=0.0):
        assert temperature >= 0.0
        if temperature == 0.0:
            # greedy decoding
            return self.gpt2.generate(
                context,
                max_length=max_length,
                eos_token_id=eos_token_id,
                do_sample=False,
            )
        else:
            # temperature sampling
            return self.gpt2.generate(
                context,
                max_length=max_length,
                eos_token_id=eos_token_id,
                do_sample=True,
                temperature=temperature,
            )


# for backwards compatibility
GPT2LM = HFLM
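

# A minimal usage sketch, not part of the original module: it assumes `lm_eval`,
# `transformers`, and `torch` are installed and that the "gpt2" checkpoint can be
# downloaded from the Hugging Face Hub. Shown only to illustrate the tok_encode /
# _model_call API; the prompt is arbitrary.
if __name__ == "__main__":
    lm = HFLM(device="cpu", pretrained="gpt2", batch_size=1)

    # encode a prompt and score it with a single forward pass
    ids = lm.tok_encode("hello\n\nhello")
    logits = lm._model_call(torch.tensor([ids], device=lm.device))
    print(logits.shape)  # expected: torch.Size([1, 4, 50257])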