Unverified Commit e9037708 authored by Daniël de Kok, committed by GitHub

Support different image sizes in prefill in VLMs (#2065)

When a batch contained images of different sizes during prefill, the
server would fail (see e.g. #2056). Images were processed separately and
then concatenated, which can fail when the images have different sizes.

Fix this by preprocessing all images in the batch together, so that the
image processor can ensure that all image tensors have compatible sizes.
parent 445f3135
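For illustration, a minimal sketch of the difference (not the server code itself; the `processor` object and image list below are placeholders): processing each image separately and concatenating the results breaks as soon as two images yield tensors of different spatial sizes, while a single call over the whole list lets the image processor pad or resize everything to a common shape first.

```python
# Sketch only; "processor" stands for a Hugging Face-style processor whose
# image_processor accepts either one image or a list and returns "pixel_values".
import torch
from PIL import Image


def preprocess_separately(processor, images: list[Image.Image]) -> torch.Tensor:
    # Old behaviour: one processor call per image, then concatenation.
    # torch.cat raises a size-mismatch error if two images end up with
    # different spatial shapes.
    parts = [processor.image_processor(img, return_tensors="pt") for img in images]
    return torch.cat([p["pixel_values"] for p in parts], dim=0)


def preprocess_together(processor, images: list[Image.Image]) -> torch.Tensor:
    # New behaviour: a single call over the whole batch, so the processor can
    # bring all images to compatible shapes before stacking them.
    return processor.image_processor(images, return_tensors="pt")["pixel_values"]
```

The old path is what produced the concatenation failure reported in #2056; the new path corresponds to the batched preprocessing introduced in this commit.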
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2502,
"logprob": -1.734375,
"special": false,
"text": "image"
},
{
"id": 2196,
"logprob": -0.5756836,
"special": false,
"text": " result"
},
{
"id": 604,
"logprob": -0.007843018,
"special": false,
"text": " for"
},
{
"id": 12254,
"logprob": -1.7167969,
"special": false,
"text": " chicken"
},
{
"id": 611,
"logprob": -0.17053223,
"special": false,
"text": " on"
},
{
"id": 573,
"logprob": -0.7626953,
"special": false,
"text": " the"
},
{
"id": 8318,
"logprob": -0.02709961,
"special": false,
"text": " beach"
},
{
"id": 1,
"logprob": -0.20739746,
"special": true,
"text": "<eos>"
}
],
"top_tokens": null
},
"generated_text": "image result for chicken on the beach"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 12,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 450,
"logprob": -0.26342773,
"special": false,
"text": " The"
},
{
"id": 21282,
"logprob": -0.01838684,
"special": false,
"text": " cow"
},
{
"id": 322,
"logprob": -0.18041992,
"special": false,
"text": " and"
},
{
"id": 521,
"logprob": -0.62841797,
"special": false,
"text": " ch"
},
{
"id": 21475,
"logprob": -0.0037956238,
"special": false,
"text": "icken"
},
{
"id": 526,
"logprob": -0.018737793,
"special": false,
"text": " are"
},
{
"id": 373,
"logprob": -1.0820312,
"special": false,
"text": " on"
},
{
"id": 263,
"logprob": -0.5083008,
"special": false,
"text": " a"
},
{
"id": 25695,
"logprob": -0.07128906,
"special": false,
"text": " beach"
},
{
"id": 29889,
"logprob": -0.12573242,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.0029792786,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -0.00024962425,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": " The cow and chicken are on a beach."
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 415,
"logprob": -0.04421997,
"special": false,
"text": " The"
},
{
"id": 12072,
"logprob": -0.13500977,
"special": false,
"text": " cow"
},
{
"id": 349,
"logprob": -0.06750488,
"special": false,
"text": " is"
},
{
"id": 6328,
"logprob": -0.6352539,
"special": false,
"text": " standing"
},
{
"id": 356,
"logprob": -0.16186523,
"special": false,
"text": " on"
},
{
"id": 272,
"logprob": -0.5078125,
"special": false,
"text": " the"
},
{
"id": 10305,
"logprob": -0.017913818,
"special": false,
"text": " beach"
},
{
"id": 304,
"logprob": -1.5205078,
"special": false,
"text": " and"
},
{
"id": 272,
"logprob": -0.029174805,
"special": false,
"text": " the"
},
{
"id": 13088,
"logprob": -0.003479004,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.0035095215,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.3088379,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.027755737,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.31884766,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.047943115,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002925396,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.02935791,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.031219482,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.00034475327,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money."
}
@@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
     return flash_pali_gemma_handle.client


+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 def get_cow_beach():
     with open("integration-tests/images/cow_beach.png", "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read())
@@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
     assert response.generated_text == "beach"
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_pali_gemma.generate(
+        f"caption![]({chicken})![]({cow_beach})\n",
+        max_new_tokens=20,
+    )
+    # Is PaliGemma not able to handle two separate images? At least we
+    # get output showing that both images are used.
+    assert (
+        response.generated_text == "image result for chicken on the beach"
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
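The same two-image prefill can be reproduced against a running server without the pytest harness. A hedged sketch follows; the endpoint URL and local image paths are assumptions, and the prompt mirrors the markdown-style image references used in the test above.

```python
# Hypothetical client-side reproduction of the two-image request; assumes a
# text-generation-inference server listening on localhost:8080.
import base64

import requests


def data_uri(path: str) -> str:
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


chicken = data_uri("integration-tests/images/chicken_on_money.png")
cow_beach = data_uri("integration-tests/images/cow_beach.png")

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": f"caption![]({chicken})![]({cow_beach})\n",
        "parameters": {"max_new_tokens": 20},
    },
)
print(response.json()["generated_text"])
```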
@@ -23,6 +23,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.mark.asyncio
 async def test_idefics(idefics, response_snapshot):
     chicken = get_chicken()
@@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_idefics_two_images(idefics, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await idefics.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text == " The cow and chicken are on a beach."
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
+
+
 @pytest.mark.asyncio
 async def test_idefics_load(idefics, generate_load, response_snapshot):
     chicken = get_chicken()
@@ -9,6 +9,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.fixture(scope="module")
 def flash_idefics2_next_handle(launcher):
     with launcher(
@@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
     assert response == response_snapshot


+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_idefics2_next.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text
+        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 20
+    assert response == response_snapshot
+
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
@@ -53,7 +53,9 @@ def image_text_replacement(image_input, config, image_id) -> str:
         num_features = get_number_of_features(height, width, config)
         from loguru import logger

-        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        logger.info(
+            f"Found {num_features} features in image of resolution {height}x{width}"
+        )
         return "<image>" * num_features
     elif config.model_type == "paligemma":
@@ -133,23 +135,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     def batch_tokenized_inputs(
         cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
     ):
+        # Process images first. We need all of them so that the processor
+        # can make the image splits the same size. And we need the final
+        # sizes to insert correct number of image tokens.
+        images = []
+        for r in requests:
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    pass
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    if config.model_type == "llava_next":
+                        images.append(image)
+                    else:
+                        images.append([image])
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+        if images:
+            image_inputs = processor.image_processor(images, return_tensors="pt")
+        else:
+            image_inputs = None
+
         batch_inputs = []
-        image_inputs = []
         max_truncation = 0
+        image_id = 0
         for r in requests:
             full_text = ""
-            image_id = 0
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config, image_id)
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                    full_text += image_text_replacement(image_inputs, config, image_id)
+                    image_id += 1

             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -160,24 +180,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             max_length=max_truncation,
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
+
         return batch_tokenized_inputs, image_inputs

     @classmethod