Commit 2ef15732 authored by Baber

add pixtral_api

parent 3c772593
import asyncio
import base64
import copy
import itertools
import json
import os
from functools import cached_property
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr, TemplateAPI
from lm_eval.models.utils import Collator
from lm_eval.utils import eval_logger
@@ -238,3 +250,186 @@ class OpenAIChatCompletion(LocalChatCompletion):
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )


@register_model("pixtral-api")
class PixtralAPI(LocalChatCompletion):
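    """Multimodal chat-completions model for Pixtral served behind an
    OpenAI-compatible endpoint. Prompts are sent untokenized, and `<image>`
    placeholders in the chat history are replaced with base64-encoded JPEG
    `image_url` entries before the request is dispatched."""
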
    MULTIMODAL = True
    DEFAULT_IMAGE_PLACEHOLDER = "<image>"

    def __init__(
        self,
        max_images: int = 999,
        **kwargs,
    ):
        self.max_images = max_images
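        # The serving endpoint handles tokenization, so no local tokenizer
        # backend is configured and prompts are sent as JSON chat strings
        # rather than token IDs.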
        super().__init__(
            tokenizer_backend=None,
            tokenized_requests=False,
            model="mistralai/Pixtral-12B-2409",
            **kwargs,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
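        """Generate text for each request, either sequentially in batches or
        via concurrent API calls, depending on `self._concurrent`."""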
        res = []

        def _collate_gen(_requests):
            # sort by the length of the non-tokenized contexts
            return -len(_requests[0])

        # Let the API deal with tokenization
        requests, all_gen_kwargs, aux_args = zip(*(req.args for req in requests))
        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
            )
        else:
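            # Untokenized path: splice each request's PIL image(s) into the
            # JSON chat string as base64 data URLs before sending.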
            requests = [
                self.update_json_chat_str_with_image(req, pil_image["visual"])
                for req, pil_image in zip(requests, aux_args)
            ]
            encodings_list = [None] * len(requests)
        requests = [
            (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
        ]

        re_ord = Collator(
            requests,
            sort_fn=_collate_gen,
            group_by="gen_kwargs",
        )
        chunked = re_ord.get_batched(
            n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
        )
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
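                # Retry transient API failures with exponential backoff,
                # re-raising the last error once retries are exhausted.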
                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(
                    messages=req,
                    generate=True,
                    gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                )
                for generated_text, context in zip(
                    self.parse_generations(
                        outputs=outputs,
                        contexts=contexts,
                    ),
                    contexts,
                ):
                    if generated_text is not None:
                        res.append(generated_text)

                        # partial caching
                        if context is not None:
                            self.cache_hook.add_partial(
                                "generate_until",
                                (context, all_gen_kwargs[0]),
                                generated_text,
                            )
                        pbar.update(1)
        else:
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                results = itertools.chain.from_iterable(
                    asyncio.run(
                        self.get_batched_requests(
                            req,
                            cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
                            generate=True,
                            gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                        )
                    )
                )
                res.extend(results)

        return re_ord.get_original(res)

    @staticmethod
    def encode_pillow_image(img):
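        """Encode a PIL image as a base64 JPEG string, converting palette
        images to RGB and flattening RGBA images onto a white background
        (JPEG has no alpha channel)."""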
        if img.mode == "P":
            img = img.convert("RGB")
        if img.mode == "RGBA":
            # Create a white background
            background = Image.new("RGB", img.size, (255, 255, 255))
            # Paste the image onto the background;
            # the alpha channel is automatically used as the mask
            background.paste(img, mask=img.split()[3])
            img = background
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def update_json_chat_str_with_image(
        self, json_chat_str, pil_images: Union["Image.Image", List["Image.Image"]]
    ):
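        """Fill the `image_url` placeholders in a JsonChatStr prompt with
        base64 data URLs for the given PIL image(s)."""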
        # Parse the JSON string
        chat_data = json.loads(json_chat_str.prompt)

        # Convert a single image to a list for consistency
        if not isinstance(pil_images, list):
            pil_images = [pil_images]

        # Encode the Pillow image(s)
        base64_images = [self.encode_pillow_image(img) for img in pil_images]

        # Update the image_url(s) in the chat data
        image_index = 0
        for message in chat_data:
            if message["role"] == "user":
                for content in message["content"]:
                    if content["type"] == "image_url":
                        if image_index < len(base64_images):
                            content["image_url"] = {
                                "url": f"data:image/jpeg;base64,{base64_images[image_index]}"
                            }
                            image_index += 1
                        else:
                            # If we run out of images, leave the placeholder empty
                            content["image_url"] = None

        # Update the JsonChatStr object with the new JSON string
        json_chat_str = JsonChatStr(json.dumps(chat_data))
        return json_chat_str

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]]
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to the chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )
        else:
            # Bit of a hack: we serialize to JSON here and load it back
            # before sending to the API.
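            # e.g. "Describe this. <image> Answer briefly." becomes
            #   [{"type": "text", "text": "Describe this."},
            #    {"type": "image_url", "image_url": None},
            #    {"type": "text", "text": "Answer briefly."}],
            # with image_url filled in later by update_json_chat_str_with_image.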
            new_messages = []
            for message in chat_history:
                if message["role"] == "user":
                    # Split the content at the <image> placeholder
                    parts = message["content"].split("<image>")
                    new_content = [
                        {"type": "text", "text": parts[0].strip()},
                        {"type": "image_url", "image_url": None},
                    ]
                    if len(parts) > 1:
                        new_content.append({"type": "text", "text": parts[1].strip()})
                    new_messages.append(
                        {"role": message["role"], "content": new_content}
                    )
                else:
                    # For non-user messages, keep the format as is
                    new_messages.append(message)
            return JsonChatStr(json.dumps(new_messages))
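
A minimal usage sketch for the new model type, assuming a Pixtral checkpoint served behind an OpenAI-compatible chat-completions endpoint; the base_url and the mmmu_val task name below are placeholders, not part of this commit:

import lm_eval

results = lm_eval.simple_evaluate(
    model="pixtral-api",
    # Placeholder endpoint; any OpenAI-compatible server hosting
    # mistralai/Pixtral-12B-2409 should work the same way.
    model_args="base_url=http://localhost:8000/v1/chat/completions,max_images=8",
    tasks=["mmmu_val"],
    apply_chat_template=True,
)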