gpt2.py
import transformers
import torch
from lm_eval.base import BaseLM


class HFLM(BaseLM):
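    """Hugging Face causal LM wrapper (GPT-2 by default) exposing the
    BaseLM interface used by the evaluation harness."""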

    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        if device:
            self._device = torch.device(device)
        else:
            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
        ).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)

        assert isinstance(self.tokenizer, (
            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
            transformers.T5Tokenizer, transformers.T5TokenizerFast,
        )), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
                self.tokenizer.encode('hello\n\nhello')

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
    
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            # keep only the first 50257 logits so scores line up with the
            # GPT-2 tokenizer vocabulary (some checkpoints pad the output
            # embedding beyond the tokenizer's vocab size)
            return self.gpt2(inps)[0][:, :, :50257]
    
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False
        )


# for backwards compatibility
GPT2LM = HFLM
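

# Minimal usage sketch, assuming the `gpt2` checkpoint is cached locally or
# downloadable; the variable names below are illustrative only.
if __name__ == "__main__":
    lm = HFLM(device="cpu", pretrained="gpt2", batch_size=1)
    tokens = lm.tok_encode("hello\n\nhello")  # -> [31373, 198, 198, 31373]
    logits = lm._model_call(torch.tensor([tokens], device=lm.device))
    print(logits.shape)  # -> torch.Size([1, 4, 50257])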