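"""Hugging Face `transformers` causal language model wrapper exposing the
lm_eval `BaseLM` interface."""
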
import torch
import transformers
from typing import Optional
from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        device_list = set(
            ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        )
        if device and device in device_list:
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified or not recognized, using default device")
            print(f"CUDA available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=load_in_8bit,
            low_cpu_mem_usage=low_cpu_mem_usage,
            revision=revision,
            trust_remote_code=trust_remote_code,
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # TODO: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig doesn't expose n_ctx; fall back to max_position_embeddings
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
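        # Greedy decoding: sampling is disabled so generation is deterministic;
        # when an eos_token_id is provided it serves both as the stop criterion
        # and as the padding token passed to `generate`.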
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
        return self.gpt2.generate(
            context, pad_token_id=eos_token_id, **generation_kwargs
        )

# for backwards compatibility
GPT2LM = HFLM
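

# Illustrative usage sketch: instantiate the wrapper and round-trip a string
# through its tokenizer. The model name, device, and example text below are
# arbitrary placeholder choices for demonstration only.
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
    token_ids = lm.tok_encode("language model evaluation")
    print(token_ids)
    print(lm.tok_decode(token_ids))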