Unverified Commit ba871fb7 authored by Komal Kumar Teru's avatar Komal Kumar Teru Committed by GitHub
Browse files

[Misc] support arbitrary MM datasets in spec dec bench (#33486)


Signed-off-by: default avatarkkt-cohere <komal@cohere.com>
Signed-off-by: default avatarKomal Kumar Teru <162363718+kkt-cohere@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent ab374786
...@@ -32,6 +32,7 @@ th { ...@@ -32,6 +32,7 @@ th {
| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` | | HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` |
| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` | | Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` |
| Custom | ✅ | ✅ | Local file: `data.jsonl` | | Custom | ✅ | ✅ | Local file: `data.jsonl` |
| Custom MM | ✅ | ✅ | Local file: `mm_data.jsonl` |
Legend: Legend:
...@@ -133,6 +134,33 @@ vllm bench serve --port 9001 --save-result --save-detailed \ ...@@ -133,6 +134,33 @@ vllm bench serve --port 9001 --save-result --save-detailed \
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
#### Custom multimodal dataset
If the multimodal dataset you want to benchmark is not yet supported in vLLM, you can benchmark it using `CustomMMDataset`. Your data needs to be in `.jsonl` format, and each entry needs to have a "prompt" field and an "image_files" field (a list of image paths; currently only the first image of each entry is used), e.g., `mm_data.jsonl`:
```json
{"prompt": "How many animals are present in the given image?", "image_files": ["/path/to/image/folder/horsepony.jpg"]}
{"prompt": "What colour is the bird shown in the image?", "image_files": ["/path/to/image/folder/flycatcher.jpeg"]}
```
```bash
# need a model with vision capability here
vllm serve Qwen/Qwen2-VL-7B-Instruct
```
```bash
# run benchmarking script
vllm bench serve --save-result --save-detailed \
--backend openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name custom_mm \
--dataset-path <path-to-your-mm-data-jsonl> \
--allowed-local-media-path /path/to/image/folder
```
Note that we need to use the `openai-chat` backend and `/v1/chat/completions` endpoint for multimodal inputs.
#### VisionArena Benchmark for Vision Language Models #### VisionArena Benchmark for Vision Language Models
```bash ```bash
......
...@@ -5,7 +5,6 @@ from transformers import AutoTokenizer ...@@ -5,7 +5,6 @@ from transformers import AutoTokenizer
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector from vllm.v1.metrics.reader import Counter, Vector
try: try:
...@@ -56,6 +55,7 @@ def parse_args(): ...@@ -56,6 +55,7 @@ def parse_args():
default="eagle", default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"], choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
) )
parser.add_argument("--backend", type=str, default="openai")
parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-max", type=int, default=5)
parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument("--prompt-lookup-min", type=int, default=2)
...@@ -75,12 +75,11 @@ def parse_args(): ...@@ -75,12 +75,11 @@ def parse_args():
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None) parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--allowed-local-media-path", type=str, default="")
return parser.parse_args() return parser.parse_args()
def main(args): def main(args):
args.endpoint_type = "openai-chat"
model_dir = args.model_dir model_dir = args.model_dir
if args.model_dir is None: if args.model_dir is None:
if args.custom_mm_prompts: if args.custom_mm_prompts:
...@@ -91,19 +90,25 @@ def main(args): ...@@ -91,19 +90,25 @@ def main(args):
) )
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(model_dir)
args.custom_skip_chat_template = True
if not args.custom_mm_prompts: if args.custom_mm_prompts:
prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
else:
prompts = get_samples(args, tokenizer) prompts = get_samples(args, tokenizer)
if args.enable_multimodal_chat:
llm_prompts = [p.prompt for p in prompts]
else:
# add_special_tokens is False to avoid adding bos twice # add_special_tokens is False to avoid adding bos twice
# when using chat templates # when using chat templates
prompt_ids = [ llm_prompts = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) {
"prompt_token_ids": tokenizer.encode(
prompt.prompt, add_special_tokens=False
),
"multi_modal_data": prompt.multi_modal_data,
}
for prompt in prompts for prompt in prompts
] ]
else:
prompts = get_custom_mm_prompts(args.num_prompts)
if args.method == "eagle" or args.method == "eagle3": if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir eagle_dir = args.eagle_dir
if args.method == "eagle" and eagle_dir is None: if args.method == "eagle" and eagle_dir is None:
...@@ -154,16 +159,17 @@ def main(args): ...@@ -154,16 +159,17 @@ def main(args):
limit_mm_per_prompt={"image": 5}, limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True, disable_chunked_mm_input=True,
max_num_seqs=args.max_num_seqs, max_num_seqs=args.max_num_seqs,
allowed_local_media_path=args.allowed_local_media_path,
) )
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
if not args.custom_mm_prompts: if args.backend == "openai-chat":
outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
else:
outputs = llm.generate( outputs = llm.generate(
[TokensPrompt(prompt_token_ids=x) for x in prompt_ids], llm_prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)
# print the generated text # print the generated text
if args.print_output: if args.print_output:
...@@ -219,6 +225,8 @@ def main(args): ...@@ -219,6 +225,8 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
args.enable_multimodal_chat = args.backend == "openai-chat"
acceptance_length = main(args) acceptance_length = main(args)
if args.test: if args.test:
......
...@@ -1335,6 +1335,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1335,6 +1335,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"random-rerank", "random-rerank",
"hf", "hf",
"custom", "custom",
"custom_mm",
"prefix_repetition", "prefix_repetition",
"spec_bench", "spec_bench",
], ],
...@@ -1363,6 +1364,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1363,6 +1364,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
action="store_true", action="store_true",
help="Skip applying chat template to prompt for datasets that support it.", help="Skip applying chat template to prompt for datasets that support it.",
) )
parser.add_argument(
"--enable-multimodal-chat",
action="store_true",
help="Enable multimodal chat transformation for datasets that support it.",
)
parser.add_argument( parser.add_argument(
"--disable-shuffle", "--disable-shuffle",
action="store_true", action="store_true",
...@@ -1685,6 +1691,19 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: ...@@ -1685,6 +1691,19 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
no_oversample=args.no_oversample, no_oversample=args.no_oversample,
) )
elif args.dataset_name == "custom_mm":
dataset = CustomMMDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
)
input_requests = dataset.sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.custom_output_len,
enable_multimodal_chat=args.enable_multimodal_chat,
request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample,
)
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
dataset = SonnetDataset( dataset = SonnetDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
...@@ -1832,6 +1851,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: ...@@ -1832,6 +1851,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
enable_multimodal_chat=args.enable_multimodal_chat,
request_id_prefix=args.request_id_prefix, request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample, no_oversample=args.no_oversample,
skip_chat_template=args.skip_chat_template, skip_chat_template=args.skip_chat_template,
...@@ -1849,6 +1869,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: ...@@ -1849,6 +1869,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.spec_bench_output_len, output_len=args.spec_bench_output_len,
enable_multimodal_chat=args.enable_multimodal_chat,
request_id_prefix=args.request_id_prefix, request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample, no_oversample=args.no_oversample,
), ),
...@@ -1860,6 +1881,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: ...@@ -1860,6 +1881,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
enable_multimodal_chat=args.enable_multimodal_chat,
request_id_prefix=args.request_id_prefix, request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample, no_oversample=args.no_oversample,
), ),
...@@ -1903,6 +1925,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: ...@@ -1903,6 +1925,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
bucket_config=args.random_mm_bucket_config, bucket_config=args.random_mm_bucket_config,
enable_multimodal_chat=args.enable_multimodal_chat,
request_id_prefix=args.request_id_prefix, request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample, no_oversample=args.no_oversample,
), ),
...@@ -2075,6 +2098,85 @@ class CustomDataset(BenchmarkDataset): ...@@ -2075,6 +2098,85 @@ class CustomDataset(BenchmarkDataset):
return sampled_requests return sampled_requests
class CustomMMDataset(CustomDataset):
    """
    Implements the Custom MultiModal dataset. Loads data from a JSONL file and
    generates one sample request per entry. Every entry is a JSON object with
    a "prompt" string and an "image_files" list of image paths, e.g.
    (wrapped here for readability; each entry is a single line in the file):

    ```
    {"prompt": "How many red blocks in the given images?",
     "image_files": ["path/to/image1.png", "path/to/image2.png"]}
    {"prompt": "Which country has the most pokemons based on the given graphs?",
     "image_files": ["path/to/image.png"]}
    ```

    NOTE: Only the first image file in "image_files" is used for each sample
    request.

    This is used to benchmark multimodal LLMs on arbitrary datasets.
    """

    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Build up to `num_requests` multimodal sample requests from `self.data`.

        Args:
            tokenizer: Tokenizer used to measure the length of the raw text
                prompt.
            num_requests: Number of requests to produce. A value <= 0 means
                "use every available entry".
            output_len: Stored unchanged as `expected_output_len` on every
                request.
            enable_multimodal_chat: When True, the prompt of each request is
                replaced by the result of
                `apply_multimodal_chat_transformation(prompt, mm_content)`.
            request_id_prefix: Prefix prepended to the entry index to form the
                request id.
            no_oversample: Forwarded to `maybe_oversample_requests`.
            **kwargs: Accepted for signature compatibility; unused here.

        Returns:
            The list of sampled requests, after `maybe_oversample_requests`
            has been applied to it.

        Raises:
            ValueError: If an entry has an empty "image_files" value.
        """
        # Record how many entries the dataset holds so a non-positive
        # `num_requests` can fall back to "everything".
        self.num_available_samples = len(self.data)
        if num_requests <= 0:
            num_requests = self.num_available_samples
            logger.info(
                "num_requests is set to 0 or negative, "
                "so using all available samples: %d",
                num_requests,
            )

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break

            prompt = item["prompt"]
            # Length of the raw text prompt only; measured before any chat
            # transformation is applied.
            prompt_len = len(tokenizer(prompt).input_ids)

            images = item["image_files"]
            if isinstance(images, str):
                # Tolerate a bare path: indexing a string would otherwise
                # pick its first character as the "image file".
                images = [images]
            if len(images) == 0:
                # Fail with a message that points at the offending entry
                # instead of an opaque IndexError on images[0].
                raise ValueError(
                    f"Sample {i} has an empty 'image_files' list; "
                    "at least one image file is required."
                )
            if len(images) > 1:
                logger.warning(
                    "Multiple image files found for sample %d. "
                    "Only the first image will be used.",
                    i,
                )
            mm_content = process_image(images[0])

            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )

        # Presumably tops the list up to `num_requests` when the dataset has
        # fewer entries, unless `no_oversample` is set -- confirm against the
        # base-class implementation.
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Spec Bench Dataset Implementation # Spec Bench Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment