import torch
import transformers
from typing import Optional, Union
from lm_eval.base import BaseLM


def _get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible.

    Does not use an instantiated HF AutoConfig.
    """
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` arg to a torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        device_list = set(
            ["cuda", "cpu"]
            + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        )
        if device and device in device_list:
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=load_in_8bit,
            low_cpu_mem_usage=low_cpu_mem_usage,
            revision=revision,
            torch_dtype=_get_dtype(dtype),
            trust_remote_code=trust_remote_code,
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer,
            (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast),
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # setup for automatic batch size detection
        if batch_size == "auto":
            self.batch_size_per_gpu = batch_size
        else:
            self.batch_size_per_gpu = int(batch_size)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        generation_kwargs = {"do_sample": False, "max_length": max_length}
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
            generation_kwargs["pad_token_id"] = eos_token_id  # setting eos_token_id as pad token
        return self.gpt2.generate(context, **generation_kwargs)


# for backwards compatibility
GPT2LM = HFLM
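

# Minimal usage sketch (illustrative only, not part of the harness API): loads
# the default "gpt2" checkpoint on CPU and runs one scoring pass and one greedy
# generation pass to sanity-check the wrapper. Requires a cached checkpoint or
# network access on first run.
if __name__ == "__main__":
    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)

    # Encode a prompt without special tokens and score it with a forward pass.
    token_ids = lm.tok_encode("hello\n\nhello")
    inps = torch.tensor([token_ids], device=lm.device)
    logits = lm._model_call(inps)  # shape: [batch, sequence, vocab]
    print("logits shape:", tuple(logits.shape))

    # Greedy decoding up to 16 tokens, using EOT as both the eos and pad token.
    out = lm._model_generate(inps, max_length=16, eos_token_id=lm.eot_token_id)
    print("generated:", lm.tok_decode(out[0].tolist()))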