Commit 2ef15732 authored by Baber

add pixtral_api

parent 3c772593
import asyncio
import base64
import copy
import itertools
import json
import os
from functools import cached_property
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr, TemplateAPI
from lm_eval.models.utils import Collator
from lm_eval.utils import eval_logger

@@ -238,3 +250,186 @@ class OpenAIChatCompletion(LocalChatCompletion):
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )


@register_model("pixtral-api")
class PixtralAPI(LocalChatCompletion):
MULTIMODAL = True
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
def __init__(
self,
max_images: int = 999,
**kwargs,
):
self.max_images = max_images
super().__init__(
tokenizer_backend=None,
tokenized_requests=False,
model="mistralai/Pixtral-12B-2409",
**kwargs,
)
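
    # Hedged usage sketch (not part of the commit): assuming a vLLM or other
    # OpenAI-compatible server is hosting Pixtral, this model could be invoked
    # through the harness CLI roughly like the following. The URL, port, and
    # task name are illustrative assumptions.
    #
    #   lm_eval --model pixtral-api \
    #       --model_args base_url=http://localhost:8000/v1/chat/completions \
    #       --tasks mmmu_val \
    #       --apply_chat_template
    #
    # `base_url` is consumed by the TemplateAPI/LocalChatCompletion parents.
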
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate_gen(_requests):
            # sort by the length of the non-tokenized contexts
            return -len(_requests[0])

        # Let the API deal with tokenization
        requests, all_gen_kwargs, aux_args = zip(*(req.args for req in requests))
        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
            )
        else:
            # Inline each request's PIL image(s) into its JSON chat payload
            requests = [
                self.update_json_chat_str_with_image(req, pil_image["visual"])
                for req, pil_image in zip(requests, aux_args)
            ]
            encodings_list = [None] * len(requests)
        requests = [
            (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
        ]

        re_ord = Collator(
            requests,
            sort_fn=_collate_gen,
            group_by="gen_kwargs",
        )
        chunked = re_ord.get_batched(
            n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
        )
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                # Retry transient API failures with exponential backoff
                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(
                    messages=req,
                    generate=True,
                    gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                )
                for generated_text, context in zip(
                    self.parse_generations(
                        outputs=outputs,
                        contexts=contexts,
                    ),
                    contexts,
                ):
                    if generated_text is not None:
                        res.append(generated_text)

                        # partial caching
                        if context is not None:
                            self.cache_hook.add_partial(
                                "generate_until",
                                (context, all_gen_kwargs[0]),
                                generated_text,
                            )
                        pbar.update(1)
        else:
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                results = itertools.chain.from_iterable(
                    asyncio.run(
                        self.get_batched_requests(
                            req,
                            cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
                            generate=True,
                            gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                        )
                    )
                )
                res.extend(results)

        return re_ord.get_original(res)
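
    # Note (not part of the commit): judging from the unpacking above, each
    # multimodal Instance is expected to carry args of the form
    # (context, gen_kwargs, aux), where aux["visual"] holds the PIL image(s):
    #
    #   (JsonChatStr(prompt='[...]'), {"max_gen_toks": 256}, {"visual": <PIL.Image>})
    #
    # The gen_kwargs shown here are illustrative.
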
    @staticmethod
    def encode_pillow_image(img):
        # Palette-mode images cannot be written as JPEG directly
        if img.mode == "P":
            img = img.convert("RGB")
        if img.mode == "RGBA":
            # Create a white background
            background = Image.new("RGB", img.size, (255, 255, 255))
            # Paste the image on the background.
            # The alpha channel is automatically used as mask
            background.paste(img, mask=img.split()[3])
            img = background
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
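
    # Illustrative round-trip (not part of the commit); the 10x10 semi-opaque
    # red square is a made-up input:
    #
    #   from PIL import Image
    #   img = Image.new("RGBA", (10, 10), (255, 0, 0, 128))
    #   b64 = PixtralAPI.encode_pillow_image(img)
    #   assert b64.startswith("/9j/")  # base64 of the JPEG magic bytes FF D8 FF
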
    def update_json_chat_str_with_image(
        self, json_chat_str, pil_images: Union["Image.Image", List["Image.Image"]]
    ):
        # Parse the JSON string
        chat_data = json.loads(json_chat_str.prompt)

        # Convert single image to list for consistency
        if not isinstance(pil_images, list):
            pil_images = [pil_images]

        # Encode the Pillow image(s)
        base64_images = [self.encode_pillow_image(img) for img in pil_images]

        # Update the image_url(s) in the chat data
        image_index = 0
        for message in chat_data:
            if message["role"] == "user":
                for content in message["content"]:
                    if content["type"] == "image_url":
                        if image_index < len(base64_images):
                            content["image_url"] = {
                                "url": f"data:image/jpeg;base64,{base64_images[image_index]}"
                            }
                            image_index += 1
                        else:
                            # If we run out of images, set to None or handle as needed
                            content["image_url"] = None

        # Update the JsonChatStr object with the new JSON string
        json_chat_str = JsonChatStr(json.dumps(chat_data))
        return json_chat_str
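
    # Payload-flow sketch (not part of the commit). A user turn produced by
    # apply_chat_template below, e.g.
    #
    #   [{"role": "user", "content": [
    #       {"type": "text", "text": "Describe the picture."},
    #       {"type": "image_url", "image_url": None}]}]
    #
    # comes back from update_json_chat_str_with_image with "image_url" filled
    # in as {"url": "data:image/jpeg;base64,..."}.
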
    def apply_chat_template(
        self, chat_history: List[Dict[str, str]]
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to a list of chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )
        else:
            # bit of a hack. We'll load back before sending to the API
            new_messages = []
            for message in chat_history:
                if message["role"] == "user":
                    # Split the content at the image placeholder
                    parts = message["content"].split(self.DEFAULT_IMAGE_PLACEHOLDER)
                    new_content = [
                        {"type": "text", "text": parts[0].strip()},
                        {"type": "image_url", "image_url": None},
                    ]
                    if len(parts) > 1:
                        new_content.append({"type": "text", "text": parts[1].strip()})
                    new_messages.append(
                        {"role": message["role"], "content": new_content}
                    )
                else:
                    # For non-user messages, keep the format as is
                    new_messages.append(message)
            return JsonChatStr(json.dumps(new_messages))
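

# Hedged end-to-end sketch (not part of the commit): the endpoint URL and the
# `some_pil_image` variable are assumptions for illustration. With an
# OpenAI-compatible server hosting Pixtral, the class could be exercised as:
#
#   lm = PixtralAPI(base_url="http://localhost:8000/v1/chat/completions")
#   chat = [{"role": "user", "content": "<image>\nDescribe the picture."}]
#   payload = lm.apply_chat_template(chat)  # JsonChatStr, image_url still None
#   payload = lm.update_json_chat_str_with_image(payload, some_pil_image)
#   # payload.prompt is now a JSON chat body with the image inlined as base64.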