import asyncio
import base64
import copy
import itertools
import json
import os
from functools import cached_property
from io import BytesIO
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr, TemplateAPI
from lm_eval.models.utils import Collator, handle_stop_sequences
from lm_eval.utils import eval_logger


@register_model("local-completions")
class LocalCompletionsAPI(TemplateAPI):
    def __init__(
        self,
        base_url=None,
        tokenizer_backend="huggingface",
        **kwargs,
    ):
        eval_logger.info("Use the AI_API_KEY environment variable to set the API key.")
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        eos=None,
        **kwargs,
    ) -> dict:
        if generate:
            gen_kwargs.pop("do_sample", False)
            if "max_tokens" in gen_kwargs:
                max_tokens = gen_kwargs.pop("max_tokens")
            else:
                max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
            temperature = gen_kwargs.pop("temperature", 0)
            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
            return {
                "prompt": messages,
                "model": self.model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stop": stop,
                "seed": seed,
                **gen_kwargs,
            }
        else:
            return {
                "model": self.model,
                "prompt": messages,
                "temperature": 0,
                "max_tokens": 1,
                "logprobs": 1,
                "seed": seed,
                "echo": True,
            }

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            for choice, ctxlen in zip(
                sorted(out["choices"], key=itemgetter("index")), ctxlens
            ):
                assert ctxlen > 0, "Context length must be greater than 0"
                logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
                tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
                top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
                is_greedy = True
                for tok, top in zip(tokens_logprobs, top_logprobs):
                    if tok != max(top.values()):
                        is_greedy = False
                        break
                res.append((logprobs, is_greedy))
        return res

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
                tmp[choices["index"]] = choices["text"]
            res = res + tmp
        return res

    @property
    def api_key(self):
        return os.environ.get("AI_API_KEY", "")


@register_model("local-chat-completions")
class LocalChatCompletion(LocalCompletionsAPI):
    def __init__(
        self,
        base_url=None,
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        eval_logger.warning(
            "chat-completions endpoint requires the `--apply_chat_template` flag."
        )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )
        if self._batch_size > 1:
            eval_logger.warning(
                "Chat completions does not support batching. Defaulting to batch size 1."
            )
            self._batch_size = 1

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: dict = None,
        seed=1234,
        eos=None,
        **kwargs,
    ) -> dict:
        assert type(messages) is not str, (
            "chat-completions require the --apply_chat_template flag."
        )
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
        else:
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        return {
            "messages": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
                tmp[choices["index"]] = choices["message"]["content"]
            res = res + tmp
        return res

    def tok_encode(
        self,
        string: Union[str, Any],
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> Union[List[str], List[int], Any]:
        return string

    def loglikelihood(self, requests, **kwargs):
        raise NotImplementedError(
            "Loglikelihood is not supported for chat completions. Consider using the completions API instead."
        )


@register_model(
    "openai-completions",
)
class OpenAICompletionsAPI(LocalCompletionsAPI):
    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the `OPENAI_API_KEY` environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        assert self.model in [
            "babbage-002",
            "davinci-002",
        ], (
            f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
        )
        return super().loglikelihood(requests, **kwargs)

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        return ""


@register_model("openai-chat-completions")
class OpenAIChatCompletion(LocalChatCompletion):
    def __init__(
        self,
        base_url="https://api.openai.com/v1/chat/completions",
        tokenizer_backend=None,
        tokenized_requests=False,
        **kwargs,
    ):
        if "o1" in kwargs.get("model", ""):
            eval_logger.warning(
                "o1 models do not support `stop` and only support temperature=1"
            )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
            tokenized_requests=tokenized_requests,
            **kwargs,
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the `OPENAI_API_KEY` environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: dict = None,
        seed=1234,
        eos="<|endoftext|>",
        **kwargs,
    ) -> dict:
        assert type(messages) is not str, (
            "chat-completions require the --apply_chat_template flag."
        )
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
        else:
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        output = {
            "messages": messages,
            "model": self.model,
            "max_completion_tokens": max_tokens,
            "temperature": temperature,
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }
        if "o1" in self.model:
            output.pop("stop")
            output["temperature"] = 1
        return output


@register_model("pixtral-api")
class PixtralAPI(LocalChatCompletion):
    MULTIMODAL = True
    # Text marker for image positions in prompts; "<image>" is assumed here,
    # matching the placeholder convention used elsewhere in lm-eval-harness.
    DEFAULT_IMAGE_PLACEHOLDER = "<image>"

    def __init__(
        self,
        max_images: int = 999,
        **kwargs,
    ):
        self.max_images = max_images
        super().__init__(
            tokenizer_backend=None,
            tokenized_requests=False,
            model="mistralai/Pixtral-12B-2409",
            **kwargs,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate_gen(_requests):
            # sort by the length of the non-tokenized contexts
            return -len(_requests[0])

        # Let the API deal with tokenization
        requests, all_gen_kwargs, aux_args = zip(*(req.args for req in requests))
        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
            )
        else:
            requests = [
                self.update_json_chat_str_with_image(req, pil_image["visual"])
                for req, pil_image in zip(requests, aux_args)
            ]
            encodings_list = [None] * len(requests)
        requests = [
            (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
        ]

        re_ord = Collator(
            requests,
            sort_fn=_collate_gen,
            group_by="gen_kwargs",
        )
        chunked = re_ord.get_batched(
            n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
        )
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(
                    messages=req,
                    generate=True,
                    gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                )
                for generated_text, context in zip(
                    self.parse_generations(
                        outputs=outputs,
                        contexts=contexts,
                    ),
                    contexts,
                ):
                    if generated_text is not None:
                        res.append(generated_text)

                        # partial caching
                        if context is not None:
                            self.cache_hook.add_partial(
                                "generate_until",
                                (context, all_gen_kwargs[0]),
                                generated_text,
                            )
                    pbar.update(1)
        else:
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                results = itertools.chain.from_iterable(
                    asyncio.run(
                        self.get_batched_requests(
                            req,
                            cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
                            generate=True,
                            gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                        )
                    )
                )
                res.extend(results)

        return re_ord.get_original(res)

    @staticmethod
    def encode_pillow_image(img):
        if img.mode == "P":
            img = img.convert("RGB")
        if img.mode == "RGBA":
            # Create a white background
            background = Image.new("RGB", img.size, (255, 255, 255))
            # Paste the image on the background.
            # The alpha channel is automatically used as mask
            background.paste(img, mask=img.split()[3])
            img = background
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def update_json_chat_str_with_image(
        self, json_chat_str, pil_images: Union["Image.Image", List["Image.Image"]]
    ):
        # Parse the JSON string
        chat_data = json.loads(json_chat_str.prompt)

        # Convert single image to list for consistency
        if not isinstance(pil_images, list):
            pil_images = [pil_images]

        # Encode the Pillow image(s)
        base64_images = [self.encode_pillow_image(img) for img in pil_images]

        # Update the image_url(s) in the chat data
        image_index = 0
        for message in chat_data:
            if message["role"] == "user":
                for content in message["content"]:
                    if content["type"] == "image_url":
                        if image_index < len(base64_images):
                            content["image_url"] = {
                                "url": f"data:image/jpeg;base64,{base64_images[image_index]}"
                            }
                            image_index += 1
                        else:
                            # If we run out of images, set to None or handle as needed
                            content["image_url"] = None

        # Update the JsonChatStr object with the new JSON string
        json_chat_str = JsonChatStr(json.dumps(chat_data))
        return json_chat_str

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]]
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to a list of chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )
        else:
            # bit of a hack. We'll load back before sending to the API
            new_messages = []
            for message in chat_history:
                if message["role"] == "user":
                    # Split the content at the image placeholder
                    parts = message["content"].split(self.DEFAULT_IMAGE_PLACEHOLDER)
                    new_content = [
                        {"type": "text", "text": parts[0].strip()},
                        {"type": "image_url", "image_url": None},
                    ]
                    if len(parts) > 1:
                        new_content.append({"type": "text", "text": parts[1].strip()})
                    new_messages.append(
                        {"role": message["role"], "content": new_content}
                    )
                else:
                    # For non-user messages, keep the format as is
                    new_messages.append(message)
            return JsonChatStr(json.dumps(new_messages))
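

# Usage sketch (assumption: these backends are invoked through the lm-eval CLI;
# exact flags and defaults can vary between harness versions). The registered
# model names above are selected with `--model`, chat endpoints require
# `--apply_chat_template` as warned in LocalChatCompletion, and the API key is
# read from the environment variables referenced in the classes:
#
#   export OPENAI_API_KEY=sk-...
#   lm_eval --model openai-chat-completions \
#       --model_args model=gpt-4o-mini \
#       --tasks gsm8k --apply_chat_template
#
#   export AI_API_KEY=...
#   lm_eval --model local-completions \
#       --model_args model=my-model,base_url=http://localhost:8000/v1/completions \
#       --tasks lambada_openai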