Commit 25e46a3a authored by wanglong3
Browse files

Support custom multimodal dataset

parent 4e8af7e8
...@@ -1626,9 +1626,33 @@ class CustomDataset(BenchmarkDataset): ...@@ -1626,9 +1626,33 @@ class CustomDataset(BenchmarkDataset):
raise NotImplementedError( raise NotImplementedError(
"Only JSONL format is supported for CustomDataset.") "Only JSONL format is supported for CustomDataset.")
random.seed(self.random_seed) # random.seed(self.random_seed)
random.shuffle(self.data) # random.shuffle(self.data)
def apply_multimodal_chat_transformation(
        self,
        prompt: str,
        mm_content: Optional[
            Union[MultiModalDataDict, dict, list[dict]]
        ] = None) -> list[dict]:
    """Wrap a prompt and optional multimodal content as a chat message list.

    The result is a one-element list holding a single ``"user"`` message.
    Its ``"content"`` is a list that always begins with a text part built
    from ``prompt``; any multimodal parts follow it.

    Args:
        prompt: Text placed in the leading ``{"text": ..., "type": "text"}``
            part.
        mm_content: Optional multimodal payload. A ``list`` contributes
            each of its items as a separate part, in order; a ``dict`` is
            appended as a single part; ``None`` adds nothing.

    Returns:
        ``[{"role": "user", "content": [...]}]``.

    Raises:
        TypeError: If ``mm_content`` is neither ``None``, a ``list`` nor a
            ``dict``.
    """
    message_parts = [{"text": prompt, "type": "text"}]

    # Text-only request: nothing to attach.
    if mm_content is None:
        return [{"role": "user", "content": message_parts}]

    if isinstance(mm_content, list):
        # A list carries several parts; splice them in one by one.
        extra_parts = cast(list[dict[str, Any]], mm_content)
        message_parts.extend(extra_parts)
    elif isinstance(mm_content, dict):
        # A single part is appended as-is (same object, not a copy).
        message_parts.append(mm_content)
    else:
        raise TypeError(
            "Could not process multimodal content of type: " +
            f"{type(mm_content)}"
        )

    return [{"role": "user", "content": message_parts}]
def sample( def sample(
self, self,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
...@@ -1655,24 +1679,16 @@ class CustomDataset(BenchmarkDataset): ...@@ -1655,24 +1679,16 @@ class CustomDataset(BenchmarkDataset):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
mm_contents = item["image"]
prompt = self.apply_multimodal_chat_transformation(prompt, mm_contents)
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=item["expected_output_len"],
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
...@@ -2720,4 +2736,4 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): ...@@ -2720,4 +2736,4 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
) )
random.shuffle(requests) random.shuffle(requests)
return requests return requests
\ No newline at end of file
...@@ -250,12 +250,7 @@ async def async_request_openai_chat_completions( ...@@ -250,12 +250,7 @@ async def async_request_openai_chat_completions(
"model": "model":
request_func_input.model_name request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name else request_func_input.model,
"messages": [ "messages": request_func_input.prompt,
{
"role": "user",
"content": content
},
],
"temperature": "temperature":
0.0, 0.0,
"max_completion_tokens": "max_completion_tokens":
...@@ -281,7 +276,6 @@ async def async_request_openai_chat_completions( ...@@ -281,7 +276,6 @@ async def async_request_openai_chat_completions(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0.0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment