"configs/seko_talk/L40s/2gpu/seko_talk_bf16.json" did not exist on "56af41ebaf3d5420736be25f96aca06b910a3447"
gpt2.py 4.03 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
import transformers
from typing import Optional
from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        device_list = set(
            ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        )
        if device and device in device_list:
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=load_in_8bit,
            low_cpu_mem_usage=low_cpu_mem_usage,
            revision=revision,
            trust_remote_code=trust_remote_code,
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # sanity check: a GPT-2 (BPE) tokenizer should encode "hello\n\nhello" to these exact ids
        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # setup for automatic batch size detection
        if batch_size == "auto":
            self.batch_size_per_gpu = batch_size
        else:
            self.batch_size_per_gpu = int(batch_size)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig has no n_ctx; fall back to max_position_embeddings
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
            # use the eos token as the pad token so generation has a valid pad_token_id
            generation_kwargs["pad_token_id"] = eos_token_id
        return self.gpt2.generate(context, **generation_kwargs)


# for backwards compatibility
GPT2LM = HFLM
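

# Minimal usage sketch: wrap a small checkpoint and round-trip a string through
# the tokenizer. The checkpoint name, device, and batch size below are only
# illustrative defaults, not requirements of the class.
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
    ids = lm.tok_encode("hello\n\nhello")
    print(ids)  # token ids produced by the GPT-2 BPE tokenizer
    print(lm.tok_decode(ids))  # decodes back to the original string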