import copy
import os
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator, DistributedType, find_executable_batch_size
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import Collator, stop_sequences_criteria


eval_logger = utils.eval_logger


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args

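
# Example (editor's sketch, not part of the original source): on a hypothetical
# 2-GPU machine, `_get_accelerate_args(max_memory_per_gpu="20GiB", max_cpu_memory="60GiB")`
# would return kwargs along the lines of
#   {
#       "max_memory": {0: "20GiB", 1: "20GiB", "cpu": "60GiB"},
#       "device_map": "auto",
#       "offload_folder": "./offload",
#   }
# which are then forwarded to `from_pretrained` when `parallelize=True`.
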
@register_model("hf-auto", "hf", "huggingface")
class HFLM(LM):
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

    AUTO_MODEL_CLASS = None
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
        backend: Optional[
            Literal["default", "causal", "seq2seq"]
        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"

            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config

            if tokenizer:
                assert isinstance(
                    tokenizer, transformers.PreTrainedTokenizer
                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
                self.tokenizer = tokenizer
            else:
                # Get tokenizer
                model_name = self._model.name_or_path
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_name,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    use_fast=use_fast_tokenizer,
                )

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            accelerator = Accelerator()
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
                    + ["mps", "mps:0"]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = torch.device(device)

            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        # determine which of 'causal' and 'seq2seq' backends to use
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                device_map_option=device_map_option,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                autogptq=autogptq,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        self.model.eval()
        self.model.tie_weights()

        if isinstance(pretrained, str) and (gpus >= 1 or str(self.device) == "mps"):
            if not (parallelize or autogptq or ("device_map" in kwargs)):
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.info(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                    )

        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
        )

        self.truncation = truncation

        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        if self.tokenizer.pad_token:
            pass
        elif self.tokenizer.unk_token:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        elif self.tokenizer.eos_token:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:
            if self.config.model_type == "qwen":
                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
            else:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

        self._max_length = max_length

        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if parallelize:
                    if accelerator.num_processes > 1:
                        raise RuntimeError(
                            "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                        )
                    else:
                        pass
                elif accelerator.num_processes == 1:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
                else:
                    if gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    assert accelerator.distributed_type in [
                        DistributedType.FSDP,
                        DistributedType.MULTI_GPU,
                    ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
                    if accelerator.distributed_type == DistributedType.FSDP:
                        self._model = accelerator.prepare(self.model)
                    else:
                        self._model = accelerator.prepare_model(
                            self.model, evaluation_mode=True
                        )
                    self._device = torch.device(
                        f"cuda:{accelerator.local_process_index}"
                    )
                    self.accelerator = accelerator

                    if self.accelerator.is_local_main_process:
                        eval_logger.info(f"Using {gpus} devices with data parallelism")

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

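    # Example usage (editor's sketch; model name and arguments are illustrative only):
    #
    #     lm = HFLM(pretrained="gpt2", device="cuda", batch_size=8)
    #
    # or from the command line, passing the same kwargs through `--model_args`:
    #
    #     lm_eval --model hf --model_args pretrained=gpt2,dtype=float32 --tasks lambada_openai
    #
    # Any extra **kwargs (e.g. `load_in_8bit=True`) are forwarded to
    # `from_pretrained` by `_create_model` below.
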
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def _get_backend(
        self,
        config: transformers.AutoConfig,
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.

        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
        model type to be used.
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on a non-default backend, use that manually
            if backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            if (
                getattr(config, "model_type")
                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
            ):
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            elif (
                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            ):
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                if not trust_remote_code:
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM nor Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` "
                        "but may be an error otherwise."
                    )
                # if the model type is in neither HF transformers' causal nor seq2seq model registries,
                # then we default to AutoModelForCausalLM
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]
        return None

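    # Editor's note (illustrative, not exhaustive): with the default backend,
    # a config whose `model_type` is e.g. "t5" resolves to AutoModelForSeq2SeqLM,
    # while "gpt2" or "llama" resolves to AutoModelForCausalLM; passing
    # backend="causal" or backend="seq2seq" overrides this lookup.
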
    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
    ) -> None:
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class), please consider subclassing HFLM
        and overriding this and other methods as needed.
        """
        model_kwargs = kwargs if kwargs else {}

        if parallelize:
            model_kwargs.update(
                _get_accelerate_args(
                    device_map_option,
                    max_memory_per_gpu,
                    max_cpu_memory,
                    offload_folder,
                )
            )

        if not autogptq:
            if model_kwargs.get("load_in_4bit", None):
                assert version.parse(transformers.__version__) >= version.parse(
                    "4.30.0"
                ), "load_in_4bit requires transformers >= 4.30.0"
            if version.parse(transformers.__version__) >= version.parse("4.30.0"):
                if model_kwargs.get("load_in_4bit", None):
                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
                            model_kwargs["bnb_4bit_compute_dtype"]
                        )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=utils.get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError:
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed. "
                    "Please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]"
                )

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                trust_remote_code=trust_remote_code,
                model_basename=None if autogptq is True else Path(autogptq).stem,
                use_safetensors=True
                if autogptq is True
                else autogptq.endswith(".safetensors"),
                **model_kwargs,
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                assert version.parse(PEFT_VERSION) >= version.parse(
                    "0.4.0"
                ), "load_in_4bit requires peft >= 0.4.0"
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )

        return None

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
    ) -> None:
        """
        Helper method during initialization.

        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
        """
        if tokenizer:
            if isinstance(tokenizer, str):
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    use_fast=use_fast_tokenizer,
                )
            else:
                assert isinstance(
                    tokenizer, transformers.PreTrainedTokenizer
                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
                self.tokenizer = tokenizer
        else:
            # Get tokenizer based on 'pretrained'
            if isinstance(pretrained, str):
                model_name = pretrained
            else:
                # get the HF hub name via accessor on model
                model_name = self.model.name_or_path
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name,
                revision=revision,
                trust_remote_code=trust_remote_code,
                use_fast=use_fast_tokenizer,
            )
        return None

    def _detect_batch_size(self, requests=None, pos: int = 0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
                out = out  # Identity process so that it passes pre-commit

            return batch_size

        batch_size = forward_batch()

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            utils.clear_torch_cache()
            return batch_size

        utils.clear_torch_cache()
        return batch_size

    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """Tokenize a string, optionally left-truncating the encoded result."""
        if add_special_tokens is None:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                add_special_tokens = False
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: int = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )
        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]

    def tok_decode(self, tokens):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder
        """
        with torch.no_grad():
            if attn_mask is not None or labels is not None:
                assert attn_mask is not None and labels is not None
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        if "do_sample" not in generation_kwargs:
            generation_kwargs["do_sample"] = False
        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, context.shape[0]
        )
        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )

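    # Editor's sketch (values illustrative): a task's generation_kwargs reach
    # `_model_generate` via generate_until below, e.g.
    #
    #     self._model_generate(
    #         context=context_enc,             # [batch, ctx_len] token ids
    #         attention_mask=attn_masks,
    #         stop=["\n\n"],                   # decoded text is also cut on these post-hoc
    #         max_length=context_enc.shape[1] + self.max_gen_toks,
    #         do_sample=True, temperature=0.8,  # sampling must be requested explicitly
    #     )
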
    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]

        return logits

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
        context_enc = self.tok_encode(context, add_special_tokens=False)

        # whole_enc = self.tok_encode(context + continuation)
        # context_enc = self.tok_encode(context, add_special_tokens=False)

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = (
                    [self.eot_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        loglikelihoods = []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            if (self.world_size > 1) and (pad_amnt > 0):
                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
            else:
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

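    # Editor's note on the shapes of incoming requests, as assumed by the request
    # methods above and below (the concrete strings are illustrative only):
    #   loglikelihood:         req.args == ("Question: ... Answer:", " yes")
    #                          -> returns (sum of log-probs of " yes", is_greedy)
    #   loglikelihood_rolling: req.args == ("a full document ...",)
    #                          -> scores the whole document in rolling windows
    #   generate_until:        req.args == ("Some prompt", {"until": ["\n\n"], "max_gen_toks": 128})
    #                          -> returns generated text, cut at the stop sequences
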
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = Collator(requests, sort_fn=_collate)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
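
# Editor's sketch (hypothetical subclass, not part of the library): as suggested
# in the `_create_model` docstring, HF-compatible backends that don't fully mirror
# the `transformers` API can subclass HFLM and override the relevant hooks, e.g.
#
#     @register_model("my-custom-hf")
#     class MyCustomLM(HFLM):
#         def _create_model(self, pretrained, **kwargs) -> None:
#             # load a custom checkpoint however you like, then assign it:
#             self._model = load_my_model(pretrained)  # hypothetical loader
#
#         def _model_call(self, inps, attn_mask=None, labels=None):
#             with torch.no_grad():
#                 return self._model(inps).logits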