Unverified Commit 697b0f71 authored by Xinyuan Tong, committed by GitHub
Browse files

[Refactor] image data process in bench_serving (#6879)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 132dad87
...@@ -388,7 +388,6 @@ async def async_request_sglang_generate( ...@@ -388,7 +388,6 @@ async def async_request_sglang_generate(
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
# print(chunk_bytes)
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st latency = time.perf_counter() - st
...@@ -655,6 +654,7 @@ class DatasetRow: ...@@ -655,6 +654,7 @@ class DatasetRow:
prompt: str prompt: str
prompt_len: int prompt_len: int
output_len: int output_len: int
image_data: Optional[str] = None
def sample_mmmu_requests( def sample_mmmu_requests(
...@@ -730,42 +730,50 @@ def sample_mmmu_requests( ...@@ -730,42 +730,50 @@ def sample_mmmu_requests(
buffered = io.BytesIO() buffered = io.BytesIO()
image.save(buffered, format="JPEG") image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
image_path = f"data:image/jpeg;base64,{img_str}" image_data = f"data:image/jpeg;base64,{img_str}"
else: else:
continue continue
# Extract the question # Extract the question
question = example.get("question") question = example.get("question")
# Create the prompt with image, question # Construct the prompt
prompt = f"Question: {question}\n\nAnswer: " prompt = f"Question: {question}\n\nAnswer: "
prompt = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_path}},
{"type": "text", "text": prompt},
],
}
],
add_generation_prompt=True,
tokenize=False,
)
prompt = f"<image>{image_path}</image>{prompt}"
# Calculate token lengths try:
# Note: This is approximate since we're not rendering the actual image tokens prompt = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_data},
},
{"type": "text", "text": prompt},
],
}
],
add_generation_prompt=True,
tokenize=False,
)
except Exception as e:
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
print(f"Error applying chat template: {e}, fallback to <image> tag")
prompt = f"<image>{prompt}"
# Calculate token lengths for text only (without image data)
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
prompt_len = ( prompt_len = len(prompt_token_ids)
len(prompt_token_ids) + 512
) # Add estimate for image tokens
output_len = fixed_output_len if fixed_output_len is not None else 256 output_len = fixed_output_len if fixed_output_len is not None else 256
filtered_dataset.append( filtered_dataset.append(
DatasetRow( DatasetRow(
prompt=prompt, prompt_len=prompt_len, output_len=output_len prompt=prompt,
prompt_len=prompt_len,
output_len=output_len,
image_data=image_data,
) )
) )
...@@ -1199,34 +1207,21 @@ async def benchmark( ...@@ -1199,34 +1207,21 @@ async def benchmark(
# Use the first request for all warmup iterations # Use the first request for all warmup iterations
test_request = input_requests[0] test_request = input_requests[0]
test_prompt, test_prompt_len, test_output_len = (
test_request.prompt,
test_request.prompt_len,
test_request.output_len,
)
if lora_names is not None and len(lora_names) != 0: if lora_names is not None and len(lora_names) != 0:
lora_name = lora_names[0] lora_name = lora_names[0]
else: else:
lora_name = None lora_name = None
if "<image>" in test_prompt:
import re
image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt, re.DOTALL)
image_data = image_match.group(1) if image_match else None
test_prompt = image_match.group(2) if image_match else test_prompt
else:
image_data = None
# Create the test input once # Create the test input once
test_input = RequestFuncInput( test_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_prompt, prompt=test_request.prompt,
api_url=api_url, api_url=api_url,
prompt_len=test_prompt_len, prompt_len=test_request.prompt_len,
output_len=min(test_output_len, 32), output_len=min(test_request.output_len, 32),
lora_name=lora_name, lora_name=lora_name,
image_data=image_data, image_data=test_request.image_data,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
) )
...@@ -1271,36 +1266,23 @@ async def benchmark( ...@@ -1271,36 +1266,23 @@ async def benchmark(
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate): async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = (
request.prompt,
request.prompt_len,
request.output_len,
)
if lora_names is not None and len(lora_names) != 0: if lora_names is not None and len(lora_names) != 0:
idx = random.randint(0, len(lora_names) - 1) idx = random.randint(0, len(lora_names) - 1)
lora_name = lora_names[idx] lora_name = lora_names[idx]
else: else:
lora_name = None lora_name = None
if "<image>" in prompt:
import re
image_match = re.search(r"<image>(.*?)</image>(.*)", prompt, re.DOTALL)
image_data = image_match.group(1) if image_match else None
prompt = image_match.group(2) if image_match else prompt
else:
image_data = None
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=prompt, prompt=request.prompt,
api_url=api_url, api_url=api_url,
prompt_len=prompt_len, prompt_len=request.prompt_len,
output_len=output_len, output_len=request.output_len,
lora_name=lora_name, lora_name=lora_name,
image_data=image_data, image_data=request.image_data,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
) )
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, pbar=pbar) limited_request_func(request_func_input=request_func_input, pbar=pbar)
......
...@@ -175,6 +175,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor): ...@@ -175,6 +175,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
if not image_data: if not image_data:
return None return None
# Ensure image_data is a list
if isinstance(image_data, str):
image_data = [image_data]
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment