Unverified Commit 2cfdd0a2 authored by Baber Abbasi, committed by GitHub

use images with api models (#2981)

* use images with apis

* pacify pre-commit
parent 178fa84d
@@ -6,6 +6,7 @@ import json
import logging
from functools import cached_property
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
@@ -30,7 +31,9 @@ except ModuleNotFoundError:
    pass

import base64
from importlib.util import find_spec
from io import BytesIO

from lm_eval import utils
from lm_eval.api.instance import Instance
@@ -38,6 +41,10 @@ from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token

if TYPE_CHECKING:
    from PIL import Image

eval_logger = logging.getLogger(__name__)
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
@@ -51,7 +58,52 @@ class JsonChatStr(NamedTuple):
        return self.prompt.encode(encoding)
def create_image_prompt(
    imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
) -> dict:
    """
    Encode images to base64 and prepend them to the last message of a chat.

    Parameters
    ----------
    imgs : list[PIL.Image.Image]
        The list of images to encode to base64.
    chat : dict
        The chat history (a list of message dicts, as produced by
        apply_chat_template) whose last message receives the images.
    fmt : str, optional
        Any format Pillow understands (e.g. "PNG", "JPEG").
        Defaults to "PNG".

    Returns
    -------
    dict
        The chat history with image content inserted into the last message.
    """
    images = []
    for img in imgs:
        buf = BytesIO()
        img.save(buf, format=fmt)
        img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        img_dict = {
            "type": "image_url",
            "image_url": {
                # build the data-URL MIME type from the requested format
                "url": f"data:image/{fmt.lower()};base64,{img_b64}",
                "detail": "auto",
            },
        }
        images.append(img_dict)
    # chat is a list[dict] with keys "role", "content", and "type" ("text").
    # With images, "content" has to become a list of dicts carrying "type" and
    # "text"/"image_url". Few-shot examples are not supported yet, so only the
    # single (last) user message is modified. Its text may still contain
    # <image> placeholders, which apparently are not needed by the API class
    # (to be confirmed).
    if isinstance(chat[-1]["content"], list):
        chat[-1]["content"] = images + chat[-1]["content"]
    else:
        text_content = {"type": "text", "text": chat[-1]["content"]}
        chat[-1]["content"] = images + [text_content]
    chat[-1].pop("type")
    return chat
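
For orientation, a minimal sketch (not part of the diff) of what the new helper does to a single-turn chat; the 1x1 placeholder image and the prompt text are made up:

```python
from PIL import Image

from lm_eval.models.api_models import create_image_prompt

# a single-turn chat as produced by apply_chat_template(): content is a plain
# string and the message carries the "type": "text" marker
chat = [{"role": "user", "content": "<image>\nWhat colour is this?", "type": "text"}]
img = Image.new("RGB", (1, 1), color="red")  # placeholder 1x1 image

chat = create_image_prompt([img], chat)
# chat[-1]["content"] is now a list:
# [{"type": "image_url",
#   "image_url": {"url": "data:image/png;base64,...", "detail": "auto"}},
#  {"type": "text", "text": "<image>\nWhat colour is this?"}]
# and the message-level "type" key has been popped off.
```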
class TemplateAPI(TemplateLM):
    MULTIMODAL = True

    def __init__(
        self,
        model: str = None,
@@ -83,6 +135,7 @@ class TemplateAPI(TemplateLM):
        eos_string: str = None,
        # timeout in seconds
        timeout: int = 300,
        max_images: int = 1,
        **kwargs,
    ) -> None:
        super().__init__()
@@ -129,6 +182,7 @@ class TemplateAPI(TemplateLM):
        self.verify_certificate = verify_certificate
        self._eos_string = eos_string
        self.timeout = int(timeout)
        self.max_images = int(max_images)
        eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
        if self.tokenizer_backend is None:
@@ -265,7 +319,12 @@ class TemplateAPI(TemplateLM):
            )
        else:
            # Bit of a hack: we'll json.loads this back before sending it to the API.
            return JsonChatStr(json.dumps(chat_history, ensure_ascii=False))
            return JsonChatStr(
                json.dumps(
                    [{**item, "type": "text"} for item in chat_history],
                    ensure_ascii=False,
                )
            )

    @cached_property
    def eot_token_id(self) -> Optional[int]:
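
To make the "load back" step concrete, a small round-trip sketch; the chat contents here are illustrative only:

```python
import json

from lm_eval.models.api_models import JsonChatStr

chat_history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "<image>\nDescribe the picture."},
]

# serialize, tagging every message as "text" (as apply_chat_template now does)
payload = JsonChatStr(
    json.dumps([{**m, "type": "text"} for m in chat_history], ensure_ascii=False)
)

# ...and load it back, which is what the request-preparation code does later
messages = json.loads(payload.prompt)
assert all(m["type"] == "text" for m in messages)
```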
@@ -578,7 +637,28 @@ class TemplateAPI(TemplateLM):
            return -len(_requests[0])

        # Let the API deal with tokenization
        requests, all_gen_kwargs = zip(*(req.args for req in requests))
        if len(requests[0].args) > 2:
            assert self.tokenizer is None, (
                "tokenizer is not supported for multimodal requests yet!"
            )
            eval_logger.info(
                f"Using max_images {self.max_images}. Set in the model args."
            )
            requests, all_gen_kwargs, auxiliary_args = zip(
                *(req.args for req in requests)
            )
            requests = tuple(
                JsonChatStr(
                    json.dumps(
                        create_image_prompt(
                            y["visual"][: self.max_images], json.loads(x.prompt)
                        )
                    )
                )
                for x, y in zip(requests, auxiliary_args)
            )
        else:
            requests, all_gen_kwargs = zip(*(req.args for req in requests))

        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
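
Roughly, the new branch rebuilds each multimodal request as in the sketch below; only the `(context, gen_kwargs, {"visual": [...]})` argument shape and the `max_images` slicing come from the diff, the concrete values are invented:

```python
import json

from PIL import Image

from lm_eval.models.api_models import JsonChatStr, create_image_prompt

# one request's args in the multimodal case: (context, gen_kwargs, aux_args)
ctx = JsonChatStr(
    json.dumps(
        [{"role": "user", "content": "<image>\nWhat is shown?", "type": "text"}]
    )
)
gen_kwargs = {"max_gen_toks": 32}
aux = {"visual": [Image.new("RGB", (8, 8)), Image.new("RGB", (8, 8))]}

# with max_images=1 only the first image gets embedded into the prompt
max_images = 1
prompt = JsonChatStr(
    json.dumps(
        create_image_prompt(aux["visual"][:max_images], json.loads(ctx.prompt))
    )
)
```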