Commit 25e46a3a authored by wanglong3's avatar wanglong3
Browse files

Support custom multimodal dataset

parent 4e8af7e8
......@@ -1626,8 +1626,32 @@ class CustomDataset(BenchmarkDataset):
raise NotImplementedError(
"Only JSONL format is supported for CustomDataset.")
random.seed(self.random_seed)
random.shuffle(self.data)
# random.seed(self.random_seed)
# random.shuffle(self.data)
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
format.
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
if isinstance(mm_content, list):
content.extend(cast(list[dict[str, Any]], mm_content))
elif isinstance(mm_content, dict):
content.append(mm_content)
else:
raise TypeError(
"Could not process multimodal content of type: " +
f"{type(mm_content)}"
)
return [{"role": "user", "content": content}]
def sample(
self,
......@@ -1656,23 +1680,15 @@ class CustomDataset(BenchmarkDataset):
break
prompt = item["prompt"]
# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
mm_contents = item["image"]
prompt = self.apply_multimodal_chat_transformation(prompt, mm_contents)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
expected_output_len=item["expected_output_len"],
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
......
......@@ -250,12 +250,7 @@ async def async_request_openai_chat_completions(
"model":
request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"messages": [
{
"role": "user",
"content": content
},
],
"messages": request_func_input.prompt,
"temperature":
0.0,
"max_completion_tokens":
......@@ -281,7 +276,6 @@ async def async_request_openai_chat_completions(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment