import os import torch import transformers from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, ) from peft import __version__ as PEFT_VERSION, PeftModel import copy from collections import defaultdict from tqdm import tqdm from pathlib import Path import torch.nn.functional as F from lm_eval import utils from lm_eval.logger import eval_logger from lm_eval.api.model import LM from lm_eval.api.registry import register_model from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria from accelerate import Accelerator, find_executable_batch_size, DistributedType from typing import List, Optional, Union def _get_accelerate_args( device_map_option: Optional[str] = "auto", max_memory_per_gpu: Optional[Union[int, str]] = None, max_cpu_memory: Optional[Union[int, str]] = None, offload_folder: Optional[str] = "./offload", ) -> dict: """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" max_memory = {} if max_memory_per_gpu is not None: max_memory_per_gpu_map = { device_idx: max_memory_per_gpu for device_idx in range(torch.cuda.device_count()) } max_memory.update(max_memory_per_gpu_map) if max_cpu_memory is not None: max_memory["cpu"] = max_cpu_memory args = {} if max_memory: args["max_memory"] = max_memory args["device_map"] = device_map_option args["offload_folder"] = offload_folder return args @register_model("hf-auto", "hf", "huggingface") class HFLM(LM): """ An abstracted Huggingface model class. Enables usage with both models of `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. Supports data-parallel multi-GPU with HF Accelerate. """ AUTO_MODEL_CLASS = None _DEFAULT_MAX_LENGTH = 2048 def __init__( self, pretrained: Optional[str] = "gpt2", revision: Optional[str] = "main", subfolder: Optional[str] = None, tokenizer: Optional[str] = None, truncation: Optional[bool] = False, max_length: Optional[int] = None, device: Optional[str] = "cuda", dtype: Optional[Union[str, torch.dtype]] = "auto", batch_size: Optional[Union[int, str]] = 1, max_batch_size: Optional[int] = 64, low_cpu_mem_usage: Optional[bool] = True, trust_remote_code: Optional[bool] = False, use_fast_tokenizer: Optional[bool] = True, cache_dir: Optional[Union[str, os.PathLike]] = None, # arguments used for splitting a model across GPUs naively. # only used if `parallelize=True`. parallelize: Optional[bool] = False, device_map_option: Optional[str] = "auto", max_memory_per_gpu: Optional[Union[int, str]] = None, max_cpu_memory: Optional[Union[int, str]] = None, offload_folder: Optional[str] = "./offload", # PEFT and quantization options peft: Optional[str] = None, load_in_8bit: Optional[bool] = False, load_in_4bit: Optional[bool] = False, bnb_4bit_quant_type: Optional[str] = None, bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None, gptq: Optional[Union[bool, str]] = False, gptq_use_triton: Optional[bool] = False, ) -> None: super().__init__() assert isinstance(device, str) assert isinstance(pretrained, str) assert isinstance(batch_size, (int, str)) gpus = torch.cuda.device_count() accelerator = Accelerator() if not (parallelize or accelerator.num_processes > 1): # use user-passed device device_list = set( ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["mps", "mps:0"] ) if device: if device not in device_list: device = int(device) self._device = torch.device(device) eval_logger.info(f"Using device '{device}'") if device in ("mps", "mps:0") and "dev" not in torch.__version__: eval_logger.info( "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of " "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url " "https://download.pytorch.org/whl/nightly/cpu" ) else: eval_logger.info("Device not specified") eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") self._device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) else: if device != "cuda": eval_logger.info( f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." ) # TODO: include in warning that `load_in_8bit` etc. affect this too self._device = device model_kwargs = {} if parallelize: model_kwargs = _get_accelerate_args( device_map_option, max_memory_per_gpu, max_cpu_memory, offload_folder, ) # TODO: update this to be less of a hack once subfolder is fixed in HF revision = revision + ("/" + subfolder if subfolder is not None else "") self._config = transformers.AutoConfig.from_pretrained( pretrained, revision=revision, trust_remote_code=trust_remote_code, ) if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM elif ( not getattr(self._config, "model_type") in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES ): if not trust_remote_code: eval_logger.warning( "HF model type is neither marked as CausalLM or Seq2SeqLM. \ This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." ) # if model type is neither in HF transformers causal or seq2seq model registries # then we default to AutoModelForCausalLM self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM else: self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM assert self.AUTO_MODEL_CLASS in [ transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM, ] if not gptq: if load_in_4bit: assert ( transformers.__version__ >= "4.30.0" ), "load_in_4bit requires transformers >= 4.30.0" if transformers.__version__ >= "4.30.0": model_kwargs["load_in_4bit"] = load_in_4bit if load_in_4bit: if bnb_4bit_quant_type: model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type if bnb_4bit_compute_dtype: model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype( bnb_4bit_compute_dtype ) self._model = self.AUTO_MODEL_CLASS.from_pretrained( pretrained, revision=revision, torch_dtype=utils.get_dtype(dtype), low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code, load_in_8bit=load_in_8bit, **model_kwargs, ) else: try: from auto_gptq import AutoGPTQForCausalLM except ModuleNotFoundError: raise Exception( "Tried to load auto_gptq, but auto-gptq is not installed ", "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", ) self._model = AutoGPTQForCausalLM.from_quantized( pretrained, model_basename=None if gptq is True else Path(gptq).stem, low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code, use_safetensors=True if gptq is True else gptq.endswith(".safetensors"), use_triton=gptq_use_triton, warmup_triton=gptq_use_triton, **model_kwargs, ) if peft: if load_in_4bit: assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0" self._model = PeftModel.from_pretrained( self._model, peft, revision=revision ) # forever after, access self._model through self.model property self.model.eval() self.model.tie_weights() if gpus <= 1 and not parallelize: # place model onto device, if not using HF Accelerate in any form try: self.model.to(self.device) except ValueError: eval_logger.info( "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore." ) self.tokenizer = transformers.AutoTokenizer.from_pretrained( pretrained if tokenizer is None else tokenizer, revision=revision, trust_remote_code=trust_remote_code, use_fast=use_fast_tokenizer, ) self.truncation = truncation self.vocab_size = self.tokenizer.vocab_size self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self._max_length = max_length self.batch_schedule = 1 self.batch_sizes = {} self.max_batch_size = max_batch_size if str(batch_size).startswith("auto"): batch_size = batch_size.split(":") self.batch_size_per_gpu = batch_size[0] self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 else: self.batch_size_per_gpu = int(batch_size) # multigpu data-parallel support when launched with accelerate if gpus > 1: if parallelize: if accelerator.num_processes > 1: raise RuntimeError( "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher." ) else: pass elif gpus > accelerator.num_processes: # TODO: make sure there's still never an edge case where we unintentionally default to CPU eval_logger.warning( "WARNING: The number of total system GPUs does not match the number of spawned processes. " "If you would like to use data parallelism, please launch the script " "with 'accelerate launch *script*'. " f"Current run will proceed with {accelerator.num_processes} devices." ) self._rank = accelerator.local_process_index self._world_size = accelerator.num_processes # manually set model to use gpu, for case where many GPUs available but # only seek to use one self._device = ( torch.device(f"cuda:{accelerator.local_process_index}") if torch.cuda.is_available() else torch.device("cpu") ) try: self.model.to(self.device) except ValueError: eval_logger.info( "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore." ) else: assert accelerator.distributed_type in [ DistributedType.FSDP, DistributedType.MULTI_GPU, ], "Unsupported distributed type provided. Only DDP and FSDP are supported." if accelerator.distributed_type == DistributedType.FSDP: self._model = accelerator.prepare(self.model) else: self._model = accelerator.prepare_model( self.model, evaluation_mode=True ) self._device = torch.device(f"cuda:{accelerator.local_process_index}") self.accelerator = accelerator if self.accelerator.is_local_main_process: eval_logger.info(f"Using {gpus} devices with data parallelism") self._rank = self.accelerator.local_process_index self._world_size = self.accelerator.num_processes @property def config(self): # return the associated transformers.AutoConfig for the given pretrained model. return self._config @property def model(self): # returns the model, unwrapping it if using Accelerate if hasattr(self, "accelerator"): return self.accelerator.unwrap_model(self._model) else: return self._model @property def eot_token_id(self): # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* return self.tokenizer.eos_token_id @property def max_length(self): if self._max_length: # if max length manually set, return it return self._max_length seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") for attr in seqlen_config_attrs: if hasattr(self.model.config, attr): return getattr(self.model.config, attr) if hasattr(self.tokenizer, "model_max_length"): if self.tokenizer.model_max_length == 1000000000000000019884624838656: return self._DEFAULT_MAX_LENGTH return self.tokenizer.model_max_length return self._DEFAULT_MAX_LENGTH @property def max_gen_toks(self) -> int: return 256 @property def batch_size(self): return self.batch_size_per_gpu @property def device(self): return self._device @property def rank(self): return self._rank @property def world_size(self): return self._world_size def _detect_batch_size(self, requests=None, pos: int = 0): if requests: _, context_enc, continuation_enc = requests[pos] max_length = len( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] ) max_context_enc = len(context_enc[-(self.max_length + 1) :]) max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) else: max_length = self.max_length # if OOM, then halves batch_size and tries again @find_executable_batch_size(starting_batch_size=self.max_batch_size) def forward_batch(batch_size): if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: length = max(max_context_enc, max_cont_enc) batched_conts = torch.ones( (batch_size, length), device=self.device ).long() test_batch = torch.ones((batch_size, length), device=self.device).long() call_kwargs = { "attn_mask": test_batch, "labels": batched_conts, } else: call_kwargs = {} test_batch = torch.ones( (batch_size, max_length), device=self.device ).long() for _ in range(5): out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) out = out # Identity process so that it passes pre-commit return batch_size batch_size = forward_batch() if self.world_size > 1: # if multi-GPU, always take minimum over all selected batch sizes max_rnk_bs = torch.tensor([batch_size], device=self.device) gathered = ( self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() ) batch_size = min(gathered) utils.clear_torch_cache() return batch_size utils.clear_torch_cache() return batch_size def tok_encode(self, string: str, left_truncate_len=None): """ """ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: add_special_tokens = False elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: add_special_tokens = True encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) # left-truncate the encoded context to be at most `left_truncate_len` tokens long if left_truncate_len: encoding = encoding[-left_truncate_len:] return encoding def tok_batch_encode( self, strings: List[str], padding_side: str = "left", left_truncate_len: int = None, truncation: bool = False, ): # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. old_padding_side = self.tokenizer.padding_side self.tokenizer.padding_side = padding_side if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: add_special_tokens = False elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: add_special_tokens = True encoding = self.tokenizer( strings, truncation=truncation, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens, ) if left_truncate_len: encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] encoding["attention_mask"] = encoding["attention_mask"][ :, -left_truncate_len: ] self.tokenizer.padding_side = old_padding_side return encoding["input_ids"], encoding["attention_mask"] def tok_decode(self, tokens): if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: return self.tokenizer.decode(tokens) elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: return self.tokenizer.decode(tokens, skip_special_tokens=True) def _model_call(self, inps, attn_mask=None, labels=None): """ :param inps: torch.Tensor A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape [batch, sequence_ctx]. the size of sequence may vary from call to call :param attn_mask: torch.Tensor, optional A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM :param labels: torch.Tensor, optional A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM :return A torch tensor of shape [batch, sequence, vocab] with the logits returned from the model's decoder """ with torch.no_grad(): if attn_mask is not None or labels is not None: assert attn_mask is not None and labels is not None assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM return self.model( input_ids=inps, attention_mask=attn_mask, labels=labels ).logits else: assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM return self.model(inps).logits def _model_generate(self, context, max_length, stop, **generation_kwargs): # we require users to pass do_sample=True explicitly # for non-greedy gen. This should be reevaluated when considering beam search. if "do_sample" not in generation_kwargs.keys(): generation_kwargs["do_sample"] = False # build stopping criteria stopping_criteria = stop_sequences_criteria( self.tokenizer, stop, 1, context.shape[0] ) return self.model.generate( input_ids=context, max_length=max_length, stopping_criteria=stopping_criteria, pad_token_id=self.eot_token_id, use_cache=True, **generation_kwargs, ) def _select_cont_toks(self, logits, contlen=None, inplen=None): if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: assert ( contlen and inplen ), "Must pass input len and cont. len to select scored logits for causal LM" # discard right-padding. # also discard the input/context tokens. we'll only score continuations. logits = logits[inplen - contlen : inplen] elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: assert ( contlen and not inplen ), "Selecting scored logits for Seq2SeqLM requires only cont. len" # only discard right-padding. # the logits input to this fn only contain decoder-side tokens. logits = logits[:contlen] return logits def _encode_pair(self, context, continuation): n_spaces = len(context) - len(context.rstrip()) if n_spaces > 0: continuation = context[-n_spaces:] + continuation context = context[:-n_spaces] whole_enc = self.tok_encode(context + continuation) context_enc = self.tok_encode(context) context_enc_len = len(context_enc) continuation_enc = whole_enc[context_enc_len:] return context_enc, continuation_enc def loglikelihood(self, requests): new_reqs = [] for context, continuation in [req.args for req in requests]: if context == "": # end of text as context context_enc, continuation_enc = [self.eot_token_id], self.tok_encode( continuation ) else: context_enc, continuation_enc = self._encode_pair(context, continuation) new_reqs.append(((context, continuation), context_enc, continuation_enc)) return self._loglikelihood_tokens(new_reqs) def loglikelihood_rolling(self, requests): loglikelihoods = [] adaptive_batch_size = None if self.batch_size == "auto": # using rolling window with maximum context print("Passed argument batch_size = auto. Detecting largest batch size") batch_size = self._detect_batch_size() print(f"Determined Largest batch size: {batch_size}") adaptive_batch_size = batch_size for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)): rolling_token_windows = list( map( utils.make_disjoint_window, utils.get_rolling_token_windows( token_list=self.tok_encode(string), prefix_token=self.eot_token_id, max_seq_len=self.max_length, context_len=1, ), ) ) # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case rolling_token_windows = [(None,) + x for x in rolling_token_windows] pad_amnt = 0 if self.world_size > 1: # We pad out the external document-level iterator so the inner iterator doesn't hang mytensor = torch.tensor(len(rolling_token_windows), device=self.device) gathered = ( self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() ) pad_amnt = max(gathered) - gathered[self.rank] if pad_amnt > 0: rolling_token_windows += pad_amnt * [rolling_token_windows[0]] string_nll = self._loglikelihood_tokens( rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size, ) if (self.world_size > 1) and (pad_amnt > 0): string_nll = [x[0] for x in string_nll[:-pad_amnt]] else: # discard is_greedy string_nll = [x[0] for x in string_nll] string_nll = sum(string_nll) loglikelihoods.append(string_nll) return loglikelihoods def _loglikelihood_tokens( self, requests, disable_tqdm: bool = False, override_bs=None ): # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context res = [] def _collate(x): # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = x[1] + x[2] return -len(toks), tuple(toks) re_ord = utils.Reorderer(requests, _collate) n_reordered_requests = len(re_ord.get_reordered()) # automatic (variable) batch size detection for vectorization # pull longest context sample from request def _batch_scheduler(pos): sched = pos // int(n_reordered_requests / self.batch_schedule) if sched in self.batch_sizes: return self.batch_sizes[sched] if (len(self.batch_sizes) > 1) and ( self.batch_sizes[sched - 1] == self.max_batch_size ): # if previous batch size is already maximal, skip recomputation self.batch_sizes[sched] = self.max_batch_size return self.batch_sizes[sched] print( f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size" ) self.batch_sizes[sched] = self._detect_batch_size( re_ord.get_reordered(), pos ) print(f"Determined largest batch size: {self.batch_sizes[sched]}") return self.batch_sizes[sched] for chunk in utils.chunks( tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))), n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0, fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None, ): inps = [] cont_toks_list = [] inplens = [] conts = [] encoder_attns = [] padding_len_inp = None padding_len_cont = None # because vectorizing is annoying, we first convert each (context, continuation) pair to padded # tensors, then we pack them together into a batch, call the model, and then pick it all apart # again because vectorizing is annoying for _, context_enc, continuation_enc in chunk: # sanity check assert len(context_enc) > 0 assert len(continuation_enc) > 0 assert len(continuation_enc) <= self.max_length # how this all works (illustrated on a causal decoder-only setup): # CTX CONT # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] # model \ \ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # when too long to fit in context, truncate from the left if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: inp = torch.tensor( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], dtype=torch.long, device=self.device, ) (inplen,) = inp.shape elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: inp = torch.tensor( (context_enc)[-self.max_length :], dtype=torch.long, device=self.device, ) (inplen,) = inp.shape # build encoder attn masks encoder_attns.append(torch.ones_like(inp)) cont = torch.tensor( (continuation_enc)[-self.max_length :], # TODO: left-shift these? # TODO: our code assumes we never end up truncating conts for either model type dtype=torch.long, device=self.device, ) (contlen,) = cont.shape conts.append(cont) padding_len_cont = ( max(padding_len_cont, contlen) if padding_len_cont is not None else contlen ) padding_len_inp = ( max(padding_len_inp, inplen) if padding_len_inp is not None else inplen ) inps.append(inp) # [1, inp_length] cont_toks_list.append(continuation_enc) inplens.append(inplen) # create encoder attn mask and batched conts, if seq2seq call_kwargs = {} if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: batched_inps = utils.pad_and_concat( padding_len_inp, inps, padding_side="right" ) # [batch, padding_len_inp] elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: # TODO: left-pad encoder inps and mask? batched_inps = utils.pad_and_concat( padding_len_inp, inps ) # [batch, padding_len_inp] batched_conts = utils.pad_and_concat( padding_len_cont, conts ) # [batch, padding_len_cont] batched_encoder_mask = utils.pad_and_concat( padding_len_inp, encoder_attns ) # [batch, padding_len_inp] call_kwargs = { "attn_mask": batched_encoder_mask, "labels": batched_conts, } multi_logits = F.log_softmax( self._model_call(batched_inps, **call_kwargs), dim=-1 ) # [batch, padding_length (inp or cont), vocab] for (cache_key, _, _), logits, inplen, cont_toks in zip( chunk, multi_logits, inplens, cont_toks_list ): # Slice to original seq length contlen = len(cont_toks) # take only logits in the continuation # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable ctx_len = ( inplen + (logits.shape[0] - padding_len_inp) if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None ) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab] # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) cont_toks = torch.tensor( cont_toks, dtype=torch.long, device=self.device ).unsqueeze( 0 ) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( -1 ) # [1, seq] # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) res.append(answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer) return re_ord.get_original(res) def greedy_until(self, requests): res = defaultdict(list) re_ords = {} def _collate(x): # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = self.tok_encode(x[0]) return -len(toks), x[0] # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. grouper = utils.Grouper(requests, lambda x: str(x.args[1])) for key, reqs in grouper.get_grouped().items(): # within each set of reqs for given kwargs, we reorder by token length, descending. re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate) pbar = tqdm(total=len(requests), disable=(self.rank != 0)) # for each different set of kwargs, we execute all requests, by batch. for key, re_ord in re_ords.items(): for chunk in utils.chunks( re_ord.get_reordered(), self.batch_size, ): contexts, all_gen_kwargs = zip(*chunk) # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0] # unpack our keyword arguments. until = None if isinstance(gen_kwargs, dict): kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 if "until" in kwargs.keys(): until = kwargs.pop("until") if isinstance(until, str): until = [kwargs] elif not isinstance(until, list): raise ValueError( f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" ) else: raise ValueError( f"Expected `kwargs` to be of type `dict` but got {kwargs}" ) if not until: until = [self.tok_decode(self.eot_token_id)] if "max_gen_toks" in kwargs.keys(): max_gen_toks = kwargs.pop("max_gen_toks") else: max_gen_toks = self.max_gen_toks # first stop sequence is used to halt generation upon encountering primary_until = [until[0]] # set the max length in tokens of inputs ("context_enc") if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: # max len for inputs = max length, minus room to generate the max new tokens max_ctx_len = self.max_length - max_gen_toks elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: # max len for inputs = encoder's whole max_length max_ctx_len = self.max_length # encode, pad, and truncate contexts for this batch context_enc, attn_masks = self.tok_batch_encode( contexts, left_truncate_len=max_ctx_len, truncation=self.truncation, ) context_enc = context_enc.to(self.device) attn_masks = attn_masks.to(self.device) if "max_length" not in kwargs: kwargs["max_length"] = context_enc.shape[1] + max_gen_toks # perform batched generation cont = self._model_generate( context=context_enc, attention_mask=attn_masks, stop=primary_until, **kwargs, ) cont_toks_list = cont.tolist() for cont_toks, context in zip(cont_toks_list, contexts): # discard context + left-padding toks if using causal decoder-only LM if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: cont_toks = cont_toks[context_enc.shape[1] :] s = self.tok_decode(cont_toks) # use secondary stop seqs to cut off should-have-been-stopped content post-hoc for term in until: if len(term) > 0: # ignore '' separator, # for seq2seq case where self.tok_decode(self.eot_token_id) = '' s = s.split(term)[0] res[key].append(s) self.cache_hook.add_partial( "greedy_until", (context, gen_kwargs), s ) pbar.update(1) # reorder this group of results back to original unsorted form res[key] = re_ord.get_original(res[key]) pbar.close() return grouper.get_original(res)