datasets.py 123 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
  - ShareGPT
  - Random (synthetic)
  - Sonnet
  - BurstGPT
  - HuggingFace
  - VisionArena
"""
14

15
import argparse
16
import ast
17
18
19
20
import base64
import io
import json
import logging
21
import math
22
23
import random
from abc import ABC, abstractmethod
24
from collections.abc import Callable, Iterator, Mapping
25
from contextlib import suppress
26
from copy import deepcopy
27
28
29
from dataclasses import dataclass
from functools import cache
from io import BytesIO
30
from tempfile import NamedTemporaryFile
31
from typing import Any, cast
32
33
34

import numpy as np
from PIL import Image
35
from typing_extensions import deprecated
36
37
38
39

from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
40
from vllm.multimodal.image import convert_image_mode
41
from vllm.tokenizers import TokenizerLike
42
from vllm.utils.import_utils import PlaceholderModule
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

try:
    from datasets import load_dataset
except ImportError:
    datasets = PlaceholderModule("datasets")
    load_dataset = datasets.placeholder_attr("load_dataset")

try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")
59

60
try:
61
    from vllm.utils.argparse_utils import FlexibleArgumentParser
62
63
64
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------


@dataclass
class SampleRequest:
    """
    Represents a single inference request for benchmarking.

    Attributes:
        prompt: The prompt text, or a list of prompt texts when several
            inputs are merged into one batched request.
        prompt_len: Number of prompt tokens (summed over all entries when
            ``prompt`` is a list).
        expected_output_len: Number of output tokens expected for the request.
        multi_modal_data: Optional multimodal payload attached to the prompt.
        lora_request: Optional LoRA adapter to use for this request.
        request_id: Optional identifier; must be unique within a sampled set
            of requests.
    """

    prompt: str | list[str]
    prompt_len: int
    expected_output_len: int
    multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
    lora_request: LoRARequest | None = None
    request_id: str | None = None


# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------


class BenchmarkDataset(ABC):
    """
    Abstract base class for benchmark request datasets.

    Subclasses implement ``sample`` (and, when they read from a source,
    ``load_data``) to produce a list of ``SampleRequest`` objects.
    """

    DEFAULT_SEED = 0
    # Subclasses that produce multimodal requests override this with True.
    IS_MULTIMODAL = False

    def __init__(
        self,
        dataset_path: str | None = None,
        random_seed: int = DEFAULT_SEED,
        disable_shuffle: bool = False,
        **kwargs,
    ) -> None:
        """
        Initialize the BenchmarkDataset with an optional dataset path and random
        seed.

        Args:
            dataset_path (Optional[str]): Path to the dataset. If None, it
                indicates that a default or random dataset might be used.
            random_seed (int): Seed value for reproducible shuffling or
                sampling. Defaults to DEFAULT_SEED; ``None`` is also mapped to
                DEFAULT_SEED.
            disable_shuffle (bool): Stored as ``self.disable_shuffle``;
                presumably consulted by subclasses that shuffle their data.
            **kwargs: Ignored by the base class.
        """
        self.dataset_path = dataset_path
        # Set the random seed, ensuring that a None value is replaced with the
        # default seed.
        self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
        self.disable_shuffle = disable_shuffle
        # Filled in by ``load_data`` in subclasses that read a dataset.
        self.data = None

    def apply_multimodal_chat_transformation(
        self,
        prompt: str,
        mm_content: MultiModalDataDict | dict | list[dict] | None = None,
    ) -> list[dict]:
        """
        Transform a prompt and optional multimodal content into a chat format.
        This method is used for chat models that expect a specific conversation
        format.

        Args:
            prompt: The user text.
            mm_content: A single content dict, a list of content dicts, or
                ``None``. The entries are appended after the text part.

        Returns:
            A one-message conversation: ``[{"role": "user", "content": [...]}]``.

        Raises:
            TypeError: If ``mm_content`` is neither a dict, a list nor None.
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            if isinstance(mm_content, list):
                content.extend(cast(list[dict[str, Any]], mm_content))
            elif isinstance(mm_content, dict):
                content.append(mm_content)
            else:
                raise TypeError(
                    f"Could not process multimodal content of type: {type(mm_content)}"
                )
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:
        """
        Load data from the dataset path into self.data.

        This method must be overridden by subclasses since the method to load
        data will vary depending on the dataset format and source.

        Raises:
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError("load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        max_loras: int | None = None,
        lora_path: str | None = None,
    ) -> LoRARequest | None:
        """
        Optionally select a random LoRA request.

        This method is used when LoRA parameters are provided.  It randomly
        selects a LoRA based on max_loras, using the global ``random`` module.

        Args:
            max_loras (Optional[int]): The maximum number of LoRAs available.
                If `None`, LoRA is not used.
            lora_path (Optional[str]): Path to the LoRA parameters on disk.
                If `None`, LoRA is not used.

        Returns:
            A new [`LoRARequest`][vllm.lora.request.LoRARequest]
            (or `None` if not applicable).
        """
        if max_loras is None or lora_path is None:
            return None

        # Generate a random LoRA ID in the range [1, max_loras].
        lora_id = random.randint(1, max_loras)
        lora_request = LoRARequest(
            lora_name=str(lora_id),
            lora_int_id=lora_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        return lora_request

    @abstractmethod
    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
    ) -> list[SampleRequest]:
        """
        Abstract method to generate sample requests from the dataset.

        Subclasses must override this method to implement dataset-specific logic
        for generating a list of SampleRequest objects.

        Args:
            tokenizer (TokenizerLike): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.
            request_id_prefix (str): The prefix of request_id.
            no_oversample (bool): Whether to skip oversampling when fewer
                than ``num_requests`` samples are available (see
                ``maybe_oversample_requests``).

        Returns:
            list[SampleRequest]: A list of sample requests generated from the
            dataset.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self,
        requests: list[SampleRequest],
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
    ) -> None:
        """
        Oversample ``requests`` in place if it has fewer than ``num_requests``
        entries, then check that all request ids are unique.

        Extra requests are deep copies of randomly chosen existing requests,
        with fresh ids ``request_id_prefix + str(k)`` where ``k`` continues
        from ``len(requests)``.

        Args:
            requests (List[SampleRequest]): The current list of sampled
                requests. Extended in place.
            num_requests (int): The target number of requests.
            request_id_prefix (str): The prefix applied to generated request
                identifiers.
            no_oversample (bool): If True, return immediately without
                oversampling or checking ids.

        Raises:
            ValueError: If oversampling is needed but ``requests`` is empty, or
                if duplicate request ids are found.
        """
        if no_oversample:
            logger.info("Skipping oversampling. Total samples: %d.", len(requests))
            return

        if len(requests) < num_requests:
            if not requests:
                # random.choice([]) would fail with an unhelpful IndexError.
                raise ValueError(
                    f"Cannot oversample to {num_requests} requests: no requests "
                    "were sampled from the dataset."
                )
            # Reseeds Python's *global* RNG so oversampling is reproducible;
            # this also affects any other user of the ``random`` module.
            random.seed(self.random_seed)
            needed = num_requests - len(requests)
            additional = []
            for i in range(needed):
                req = deepcopy(random.choice(requests))
                req.request_id = request_id_prefix + str(len(requests) + i)
                additional.append(req)
            requests.extend(additional)
            logger.info("Oversampled requests to reach %d total samples.", num_requests)

        ids = [req.request_id for req in requests]
        if len(ids) != len(set(ids)):
            raise ValueError(
                "Duplicate request_id found in the sampled "
                "requests. Please ensure that each request_id "
                "is unique."
            )


# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------


def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Validate a sequence based on prompt and output lengths.

    Default pruning criteria are copied from the original `sample_hf_requests`
    and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
    from `sample_requests` in benchmark_throughput.py.

    Args:
        prompt_len: Number of prompt tokens.
        output_len: Number of output tokens.
        min_len: Minimum length for the prompt (and the output, unless
            ``skip_min_output_len_check`` is set).
        max_prompt_len: Maximum allowed prompt length (inclusive).
        max_total_len: Maximum allowed ``prompt_len + output_len`` (inclusive).
        skip_min_output_len_check: If True, do not enforce ``min_len`` on the
            output length.

    Returns:
        True if none of the pruning criteria are violated.
    """
    # Check for invalid conditions
    prompt_too_short = prompt_len < min_len
    output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
    prompt_too_long = prompt_len > max_prompt_len
    combined_too_long = (prompt_len + output_len) > max_total_len

    # Return True if none of the invalid conditions are met
    return not (
        prompt_too_short or output_too_short or prompt_too_long or combined_too_long
    )


@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Resolve ``lora_path`` via ``get_adapter_absolute_path``.

    The result is memoized per distinct ``lora_path`` by ``functools.cache``,
    so repeated lookups for the same adapter do not repeat the resolution.
    """
    resolved_path = get_adapter_absolute_path(lora_path)
    return resolved_path


# Global cache for LoRA tokenizers, keyed by an integer id (presumably the
# LoRA int id — not used within this chunk; confirm against callers).
lora_tokenizer_cache: dict[int, TokenizerLike] = {}


def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports the following input types:

    1. String input: treated as a URL or local file path. ``"file://"`` is
       prepended unless the string already starts with ``"http://"``,
       ``"https://"`` or ``"file://"``; the URL is returned as-is otherwise.

    2. Dictionary with raw image bytes: a dict with a ``'bytes'`` key is
       opened as a PIL.Image.Image and then handled as in case 3.

    3. PIL.Image.Image input: converted to RGB, saved as an in-memory JPEG,
       and returned base64-encoded as a ``data:image/jpeg`` URL.

    Returns:
        A dict of the form ``{"type": "image_url", "image_url": {"url": ...}}``.

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(image, str):
        image_url = (
            image
            if image.startswith(("http://", "https://", "file://"))
            else f"file://{image}"
        )
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(BytesIO(image["bytes"]))

    if isinstance(image, Image.Image):
        image = convert_image_mode(image, "RGB")
        with BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
        }

    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image"
        " or str or dictionary with raw image bytes."
    )


def process_video(video: Any) -> Mapping[str, Any]:
    """
    Process a single video input and return a multimedia content dictionary.

    Supports the following input types:

    1. Dictionary with raw video bytes: a dict with a ``'bytes'`` key. The
       bytes are base64-encoded into a ``data:video/mp4`` URL.

    2. String input: treated as a URL or local file path. ``"file://"`` is
       prepended unless the string already starts with ``"http://"``,
       ``"https://"`` or ``"file://"``.

    Returns:
        A dict of the form ``{"type": "video_url", "video_url": {"url": ...}}``.

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(video, dict) and "bytes" in video:
        video_bytes = video["bytes"]
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        return {
            "type": "video_url",
            "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
        }

    if isinstance(video, str):
        video_url = (
            video
            if video.startswith(("http://", "https://", "file://"))
            else f"file://{video}"
        )
        return {"type": "video_url", "video_url": {"url": video_url}}

    raise ValueError(
        f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
    )


def gen_prompt_decode_to_target_len(
    tokenizer: TokenizerLike,
    token_sequence: list[int],
    target_token_len: int,
    max_retry: int = 10,
    add_special_tokens: bool = False,
    rng: np.random.Generator | None = None,
) -> tuple[str, list[int], int]:
    """
    Ensure decoded-then-encoded prompt length matches the target token length.

    This function decodes an initial token sequence to text and re-encodes it,
    iteratively adjusting the token sequence length to match a target.
    This is necessary because some tokenizers do not guarantee a 1:1 mapping
    between consecutive tokens and the decoded-then-encoded sequence length.
    For example, for GPT2Tokenizer:
    [6880, 6881] -> ['Ġcalls', 'here'] ->
    [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']

    Args:
        tokenizer: Tokenizer used for ``decode``/``encode``; its ``vocab_size``
            bounds the random ids used for padding.
        token_sequence: Initial token ids. Only read; all adjustments operate
            on the re-encoded ids.
        target_token_len: Desired number of tokens after re-encoding.
        max_retry: Maximum number of pad/truncate adjustments before giving up.
        add_special_tokens: Forwarded to ``tokenizer.encode``.
        rng: Generator used to draw padding token ids. When ``None``, the
            global ``np.random`` state is used instead.

    Returns:
        A tuple ``(prompt, token_sequence, token_mismatch)``: the final prompt
        string, its re-encoded token ids, and
        ``len(token_sequence) - target_token_len`` if the retries ran out
        without hitting the target (0 otherwise).
    """
    remain_num_try = max_retry
    token_mismatch = 0
    while True:
        prompt = tokenizer.decode(token_sequence)
        token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        if remain_num_try <= 0:
            # Out of retries: report how far the final encoding is off.
            if len(token_sequence) != target_token_len:
                token_mismatch = len(token_sequence) - target_token_len
            break

        num_tokens = len(token_sequence)
        if num_tokens == target_token_len:
            break
        if num_tokens < target_token_len:
            # Pad with random ids; the next iteration re-checks the length
            # after another decode/encode round trip.
            num_extra = target_token_len - num_tokens
            if rng is not None:
                extra_tokens = rng.integers(
                    0, tokenizer.vocab_size, size=num_extra
                ).tolist()
            else:
                extra_tokens = np.random.randint(
                    0, tokenizer.vocab_size, size=num_extra
                ).tolist()
            token_sequence.extend(extra_tokens)
        else:
            token_sequence = token_sequence[:target_token_len]

        remain_num_try -= 1

    return prompt, token_sequence, token_mismatch


# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------

class RandomDataset(BenchmarkDataset):
    """
    Synthetic text-only dataset for serving/throughput benchmarks.

    Strategy:
    - Sample input/output token lengths per request from integer-uniform ranges
      around configured means (controlled by range_ratio).
    - Prepend a fixed random prefix of length prefix_len, drawn from the
      tokenizer's non-special token ids.
    - Generate the remaining tokens as a reproducible walk over the allowed
      (non-special) token ids:
      allowed_tokens[(offset + index + arange(input_len)) % len(allowed_tokens)].
    - Decode then re-encode/truncate to ensure prompt token counts match.
    - Uses numpy.default_rng seeded with random_seed for reproducible sampling.
    """

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Use numpy's default_rng for deterministic sampling
        # Do not use random.seed() or np.random.seed() elsewhere in this class.
        # This ensures that the RNG is isolated from global RNG state.
        self._rng = np.random.default_rng(self.random_seed)

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        batchsize: int = 1,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate ``num_requests`` synthetic text requests.

        Args:
            tokenizer: Tokenizer used to size, decode and re-encode prompts.
            num_requests: Number of requests to generate (before batching).
            request_id_prefix: Prefix prepended to each request id.
            no_oversample: Accepted for interface compatibility; exactly
                ``num_requests`` requests are generated, so no oversampling
                happens here.
            prefix_len: Number of shared random prefix tokens prepended to
                every prompt.
            range_ratio: Relative half-width of the uniform length ranges;
                must be in [0, 1).
            input_len: Mean input length, including the tokenizer's special
                tokens (they are subtracted before sampling).
            output_len: Mean expected output length.
            batchsize: When > 1, consecutive requests are merged into batched
                requests with list prompts (only used for the embeddings
                benchmark).
            **kwargs: Ignored.

        Returns:
            The generated requests (batched when ``batchsize > 1``).

        Raises:
            ValueError: If ``prefix_len`` is negative, if the configuration can
                produce an empty prompt, or if ``range_ratio`` is not in [0, 1).
        """
        if prefix_len < 0:
            raise ValueError(f"prefix_len must be non-negative, got {prefix_len}.")

        # validate total input tokens (prefix + sampled) is at least 1.
        num_special = int(tokenizer.num_special_tokens_to_add())
        real_input_len = max(0, int(input_len) - num_special)
        min_sampled_input = math.floor(real_input_len * (1.0 - float(range_ratio)))
        min_total_input = int(prefix_len) + min_sampled_input
        if min_total_input < 1:
            raise ValueError(
                "--random-input-len is too small: with tokenizer special "
                f"tokens {num_special} and --random-range-ratio {range_ratio}, "
                "the minimum possible total input tokens (prefix + sampled) is "
                f"{min_total_input}. Increase --random-input-len and/or "
                "--random-prefix-len, or decrease --random-range-ratio so that "
                "prefix_len + floor(max(0, random_input_len - num_special)) "
                "* (1 - range_ratio) >= 1."
            )

        # NOTE: the RNG draw order (sampling params, then prefix, then
        # per-request padding) determines the generated prompts; keep it stable
        # for reproducibility.
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
        )

        vocab_size = tokenizer.vocab_size
        # Exclude special token ids from both the prefix and the body.
        prohibited_tokens = tokenizer.all_special_ids
        all_tokens = np.arange(vocab_size)
        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))

        # Generate prefix once
        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)

        requests = []
        token_mismatch_total = 0
        for i in range(num_requests):
            prompt, total_input_len, token_mismatch = self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=prefix_token_ids,
                prefix_len=prefix_len,
                vocab_size=vocab_size,
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
                allowed_tokens=allowed_tokens,
            )
            token_mismatch_total += token_mismatch
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )

        # only used for embeddings benchmark.
        if batchsize > 1:
            batch_requests = []
            # Create batched requests
            for i in range(0, num_requests, batchsize):
                batch = requests[i : i + batchsize]
                batch_requests.append(
                    SampleRequest(
                        prompt=[req.prompt for req in batch],
                        prompt_len=sum(req.prompt_len for req in batch),
                        expected_output_len=0,
                        request_id=request_id_prefix + str(i // batchsize),
                    )
                )
            requests = batch_requests

        if token_mismatch_total != 0:
            sign = "more" if token_mismatch_total > 0 else "fewer"
            logger.warning(
                "Across all generated prompts, there were %d %s tokens "
                "than expected after decoding and re-encoding. This is "
                "expected due to the imperfect nature of the sampling "
                "procedure.",
                abs(token_mismatch_total),
                sign,
            )

        return requests

    def get_prefix(
        self,
        allowed_tokens: np.ndarray,
        prefix_len: int,
    ) -> list[int]:
        """
        Get the prefix for the dataset.

        Draws ``prefix_len`` token ids uniformly (with replacement) from
        ``allowed_tokens`` using the instance RNG; returns ``[]`` when
        ``prefix_len`` is not positive.
        """
        return (
            allowed_tokens[
                self._rng.integers(0, len(allowed_tokens), size=prefix_len)
            ].tolist()
            if prefix_len > 0
            else []
        )

    def get_sampling_params(
        self,
        num_requests: int,
        range_ratio: float,
        input_len: int,
        output_len: int,
        tokenizer: TokenizerLike,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the sampling parameters for the dataset.

        Returns:
            ``(input_lens, output_lens, offsets)``, each of length
            ``num_requests``. Input lengths are sampled around
            ``input_len`` minus the tokenizer's special tokens; output lengths
            are at least 1; offsets are in ``[0, vocab_size)``.

        Raises:
            ValueError: If ``range_ratio`` is not in [0, 1) or a sampling
                interval is empty.
        """
        # Enforce range_ratio < 1
        if not (0.0 <= range_ratio < 1.0):
            raise ValueError("range_ratio must be in [0, 1).")
        num_special_tokens = int(tokenizer.num_special_tokens_to_add())
        real_input_len = max(0, int(input_len) - num_special_tokens)
        # Bounds use floor for low and ceil for high
        input_low = math.floor(real_input_len * (1 - range_ratio))
        input_high = math.ceil(real_input_len * (1 + range_ratio))
        output_low = math.floor(output_len * (1 - range_ratio))
        output_high = math.ceil(output_len * (1 + range_ratio))
        # Ensure the lower bound for output length is at least 1 to
        # prevent sampling 0 tokens.
        output_low = max(output_low, 1)
        output_high = max(output_high, 1)

        if input_low > input_high:
            raise ValueError(
                f"Invalid input sampling interval: low={input_low} > high={input_high}"
            )
        if output_low > output_high:
            raise ValueError(
                "Invalid output sampling interval: "
                f"low={output_low} > high={output_high}"
            )

        logger.info(
            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
            input_low,
            input_high,
            output_low,
            output_high,
        )

        # integers() has an exclusive upper bound, hence the +1.
        input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
        output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
        offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests)
        return input_lens, output_lens, offsets

    def generate_token_sequence(
        self,
        *,
        tokenizer: TokenizerLike,
        prefix_token_ids: list[int],
        prefix_len: int,
        vocab_size: int,
        input_len: int,
        offset: int,
        index: int,
        allowed_tokens: np.ndarray,
    ) -> tuple[str, int, int]:
        """
        Returns (prompt, total_input_len, token_mismatch).

        ``total_input_len`` is the token count of the final re-encoded prompt,
        and ``token_mismatch`` is how far it deviates from
        ``prefix_len + input_len`` (see ``gen_prompt_decode_to_target_len``).
        ``vocab_size`` is accepted for interface compatibility and is not
        used here.

        NOTE: After decoding the prompt we have to encode and decode it again.
        This is done because in some cases N consecutive tokens
        give a string tokenized into != N number of tokens.
        For example for GPT2Tokenizer:
        [6880, 6881] -> ['Ġcalls', 'here'] ->
        [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
        To avoid uncontrolled change of the prompt length,
        the encoded sequence is truncated before being decoded again.
        """
        # Build the inner sequence by sampling
        # sequentially from the allowed tokens
        inner_seq = allowed_tokens[
            (offset + index + np.arange(input_len)) % len(allowed_tokens)
        ].tolist()
        token_sequence = prefix_token_ids + inner_seq

        # Decode, then re-encode and truncate to preserve token count invariants
        total_input_len = prefix_len + int(input_len)
        prompt, adjusted_token_sequence, token_mismatch = (
            gen_prompt_decode_to_target_len(
                tokenizer=tokenizer,
                token_sequence=token_sequence,
                target_token_len=total_input_len,
                add_special_tokens=False,
                rng=self._rng,
            )
        )
        total_input_len = len(adjusted_token_sequence)
        return prompt, total_input_len, token_mismatch


# -----------------------------------------------------------------------------
# Random Dataset for Reranking Implementation (Synthetic Data)
# -----------------------------------------------------------------------------


class RandomDatasetForReranking(RandomDataset):
    """
    Random dataset specialized for the needs of scoring:
    - Batches of inputs
    - Inputs composed of pairs

    A single synthetic query is generated and shared by every batch, and each
    request's prompt is ``[query] + documents``.
    - With ``is_reranker=True``, each (query, document) pair is sized to about
      ``input_len`` tokens, and one separator token per pair is counted in
      ``prompt_len``.
    - With ``is_reranker=False``, the query and the documents are each sized
      around ``input_len``, and the query takes one slot of every batch.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        request_id_prefix: str = "",
        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
        batchsize: int = 1,
        is_reranker: bool = True,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate batched (query, documents) scoring requests.

        Args:
            tokenizer: Tokenizer used to size, decode and re-encode prompts.
            num_requests: Total number of items. In reranker mode this is the
                number of documents; otherwise the query counts as one item.
            request_id_prefix: Prefix prepended to each batch's request id.
            range_ratio: Relative half-width of the uniform length ranges.
            input_len: Target length per (query, document) pair in reranker
                mode, or per query/document otherwise.
            batchsize: Number of items per batched request.
            is_reranker: Whether to build reranker-style pairs.
            **kwargs: Ignored.

        Returns:
            One ``SampleRequest`` per batch with a list prompt
            ``[query, doc_1, ..., doc_k]`` and ``expected_output_len=0``.

        Raises:
            ValueError: If ``is_reranker`` is False and ``num_requests`` or
                ``batchsize`` is not greater than 1.
        """
        n_sep_tokens = int(is_reranker)

        # Validate with an explicit exception: ``assert`` is stripped under -O.
        if not is_reranker and (num_requests <= 1 or batchsize <= 1):
            raise ValueError(
                "is_reranker=False requires num_requests > 1 and batchsize > 1, "
                f"got num_requests={num_requests} and batchsize={batchsize}."
            )

        # Reranker: the query takes about half of the pair budget, minus the
        # separator token.
        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len

        query_lens, _, query_offsets = self.get_sampling_params(
            1, range_ratio, query_len_param, 0, tokenizer
        )

        query_len = int(query_lens[0])

        if not is_reranker:
            # The query occupies one item overall and one slot per batch.
            num_requests -= 1
            batchsize -= 1
            doc_len_param = input_len
        else:
            doc_len_param = input_len - query_len - n_sep_tokens

        doc_lens, _, doc_offsets = self.get_sampling_params(
            num_requests, range_ratio, doc_len_param, 0, tokenizer
        )

        vocab_size = tokenizer.vocab_size
        # Exclude special token ids from generated text.
        prohibited_tokens = tokenizer.all_special_ids
        all_tokens = np.arange(vocab_size)
        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))

        query_prompt, query_input_len, token_mismatch_total = (
            self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=[],
                prefix_len=0,
                vocab_size=vocab_size,
                input_len=query_len,
                offset=int(query_offsets[0]),
                index=0,
                allowed_tokens=allowed_tokens,
            )
        )

        # (prompt, token count) for each document.
        docs = []
        for i in range(num_requests):
            prompt, total_input_len, token_mismatch = self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=[],
                prefix_len=0,
                vocab_size=vocab_size,
                input_len=int(doc_lens[i]),
                offset=int(doc_offsets[i]),
                index=i + 1,
                allowed_tokens=allowed_tokens,
            )
            token_mismatch_total += token_mismatch
            docs.append((prompt, total_input_len))

        batch_requests = []
        # Create batched requests
        for i in range(0, num_requests, batchsize):
            batch = docs[i : i + batchsize]
            # In reranker mode the query (plus separator) is paired with every
            # document, so it is counted once per document.
            query_contrib = (
                (query_input_len + n_sep_tokens) * len(batch)
                if is_reranker
                else query_input_len
            )
            batch_requests.append(
                SampleRequest(
                    prompt=[query_prompt] + [doc[0] for doc in batch],
                    prompt_len=query_contrib + sum(doc[1] for doc in batch),
                    expected_output_len=0,
                    request_id=request_id_prefix + str(i // batchsize),
                )
            )

        if token_mismatch_total != 0:
            logger.warning(
                "Across all generated prompts, there were %d %s tokens "
                "than expected after decoding and re-encoding. This is "
                "expected due to the imperfect nature of the sampling "
                "procedure.",
                abs(token_mismatch_total),
                "more" if token_mismatch_total > 0 else "fewer",
            )

        return batch_requests


# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------

786

787
788
789
790
791
792
class RandomMultiModalDataset(RandomDataset):
    """
    Synthetic multimodal dataset (text + images/videos) that extends
    RandomDataset.

    Status:
    - Images: supported via synthetic RGB data.
    - Video: supported via synthetic RGB data.
    - Audio: not yet supported.

    Sampling overview:
    1) Number of items per request is sampled uniformly from the integer range
       [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
       `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
       The maximum is further clamped to the sum of per-modality limits.
    2) Each item’s modality and shape is sampled from `bucket_config`, a dict
       mapping (height, width, num_frames) → probability. We treat
       `num_frames`=1 as image and `num_frames` > 1 as video.
       Entries with zero probability are removed and the rest are renormalized
       to sum to 1.
    3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
       When a modality reaches its cap, all of its buckets are excluded and the
       remaining probabilities are renormalized.

    Example bucket configuration:
    {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
      - Two image buckets (`num_frames`=1) and one video bucket
        (`num_frames`=16).
    """

    IS_MULTIMODAL = True
    DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 1}

    DEFAULT_BASE_ITEMS_PER_REQUEST = 1
    DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
    DEFAULT_MM_ITEM_BUCKET_CONFIG = {
        (256, 256, 1): 0.5,
        (720, 1280, 1): 0.5,
        (720, 1280, 16): 0.0,
    }
    DEFAULT_ENABLE_MULTIMODAL_CHAT = False

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
        """Generate synthetic PIL image with random RGB values.

        NOTE: iid pixel sampling results in worst-case compression
        (good for stressing I/O), but very unlike real photos.
        We could consider a “low-freq” mode (e.g., noise blur)
        to emulate network realism instead of max stress.
        """
        random_pixels = self._rng.integers(
            0,
            256,
            (height, width, 3),
            dtype=np.uint8,
        )
        return Image.fromarray(random_pixels)

    def generate_synthetic_video(
        self, width: int, height: int, num_frames: int
    ) -> dict:
        """Generate a synthetic MP4 video with random pixel values.

        The `num_frames` frames of size (height, width) are filled with iid
        random RGB values and encoded to MP4 ("mp4v", 30 fps) through OpenCV
        using a temporary file. The temporary file is always removed.

        Returns:
            A dict ``{"bytes": <encoded MP4 content>}``.

        Raises:
            RuntimeError: If OpenCV cannot open the video writer.
        """
        import os

        import cv2

        random_pixels = self._rng.integers(
            0,
            256,
            (num_frames, height, width, 3),
            dtype=np.uint8,
        )

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        fps = 30  # frames per second

        # cv2.VideoWriter needs a filesystem path: reserve a named temp file
        # and close our handle before OpenCV writes to it. delete=False keeps
        # the file around, so it must be removed explicitly below.
        with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_path = temp_file.name

        try:
            video_writer = cv2.VideoWriter(
                temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
            )
            if not video_writer.isOpened():
                raise RuntimeError("Failed to create video writer")
            try:
                for frame in random_pixels:
                    video_writer.write(frame)
            finally:
                # Finalize the container even if a frame write fails.
                video_writer.release()

            with open(temp_path, "rb") as f:
                video_content = f.read()
        finally:
            # Previously the temp file leaked on every call.
            with suppress(OSError):
                os.remove(temp_path)

        return {"bytes": video_content}

    def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
        """Map a (height, width, num_frames) config to its modality.

        Returns "image" when num_frames == 1 and "video" when num_frames > 1.

        Raises:
            ValueError: If num_frames is smaller than 1.
        """
        if config[-1] == 1:
            return "image"
        elif config[-1] > 1:
            return "video"
        else:
            raise ValueError(f"Invalid multimodal item configuration: {config}")

    def normalize_bucket_config(
        self, bucket_config: dict[tuple[int, int, int], float]
    ) -> dict[tuple[int, int, int], float]:
        """
        Remove zero probability entries
        and normalize the bucket config to sum to 1.

        Returns a new dict; the input is not modified.

        Raises:
            ValueError: If any value is negative or all values are zero.
        """
        # Raise error if value is negative
        if any(v < 0 for v in bucket_config.values()):
            raise ValueError("Bucket config values must be non-negative.")
        # Remove zero probability entries
        bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
        # if bucket config is empty, raise error
        if not bucket_config:
            raise ValueError(
                "Got invalid bucket config. Bucket config values must be non-zero."
            )
        # Normalize the remaining bucket config to sum to 1
        total = sum(bucket_config.values())
        return {k: v / total for k, v in bucket_config.items()}

    def generate_mm_item(
        self,
        mm_item_config: tuple[int, int, int],
    ) -> Mapping[str, Any]:
        """
        Create a synthetic image or video for a (height, width, num_frames)
        config and pass it through process_image/process_video respectively.
        This follows the OpenAI API chat completions
        https://github.com/openai/openai-python
        """
        # Resolve the modality once; map_config_to_modality raises on
        # invalid configs, so the final `else` is purely defensive.
        modality = self.map_config_to_modality(mm_item_config)
        height, width = mm_item_config[0], mm_item_config[1]
        if modality == "image":
            return process_image(self.generate_synthetic_image(width, height))
        elif modality == "video":
            return process_video(
                self.generate_synthetic_video(width, height, mm_item_config[2])
            )
        else:
            raise ValueError(f"Invalid multimodal item configuration: {mm_item_config}")

    def get_mm_item_sampling_params(
        self,
        base_items_per_request: int,
        num_mm_items_range_ratio: float,
        limit_mm_per_prompt: dict[str, int],
        bucket_config: dict[tuple[int, int, int], float],
    ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
        """
        Get the sampling parameters for the multimodal items.

        Returns:
            (min_num_mm_items, max_num_mm_items, limit_mm_per_prompt restricted
            to modalities present in the bucket config, normalized bucket
            config).

        Raises:
            ValueError: If the range ratio is outside [0, 1], a bucket's
                modality has no limit, no limits remain, or min > max.
        """
        # Enforce num_mm_items_range_ratio <= 1
        if not (0.0 <= num_mm_items_range_ratio <= 1.0):
            raise ValueError("num_mm_items_range_ratio must be in [0, 1].")

        # Ensure modalities to sample are in limit_mm_per_prompt
        for cfg in bucket_config:
            modality = self.map_config_to_modality(cfg)
            if modality not in limit_mm_per_prompt:
                raise ValueError(
                    f"Modality {modality} is not in "
                    f"limit_mm_per_prompt: "
                    f"{limit_mm_per_prompt.keys()}"
                )

        # Remove zero probability entries
        # and normalize bucket config to sum to 1
        bucket_config = self.normalize_bucket_config(bucket_config)
        logger.info(
            "Normalized bucket config: %s",
            bucket_config,
        )
        # Only consider limit per prompt for modalities in bucket config
        allowed_modalities = {self.map_config_to_modality(cfg) for cfg in bucket_config}
        limit_mm_per_prompt = {
            k: v for k, v in limit_mm_per_prompt.items() if k in allowed_modalities
        }
        if not limit_mm_per_prompt:
            raise ValueError("No valid limits for modalities present in bucket_config.")

        logger.info(
            "Updated mm-limit-per-prompt: %s",
            limit_mm_per_prompt,
        )

        # Get max and min num mm items and ensure
        # it is at most the sum of limit_mm_per_prompt for all modalities
        max_num_mm_items = min(
            sum(limit_mm_per_prompt.values()),
            math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)),
        )
        # Ensure min num mm items is at least 0
        min_num_mm_items = max(
            0, math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
        )
        # Raise error if min num mm items is greater than max num mm items
        if min_num_mm_items > max_num_mm_items:
            raise ValueError(
                f"Min num mm items is greater than max mm items: "
                f"{min_num_mm_items} > {max_num_mm_items}"
            )

        logger.info(
            "Sampling number of multimodal items from [%s, %s]",
            min_num_mm_items,
            max_num_mm_items,
        )

        return (
            min_num_mm_items,
            max_num_mm_items,
            limit_mm_per_prompt,
            bucket_config,
        )

    def get_mm_item_iterator(
        self,
        min_num_mm_items: int,
        max_num_mm_items: int,
        bucket_config: dict[tuple[int, int, int], float],
        limit_mm_per_prompt: dict[str, int],
    ) -> Iterator[tuple[int, int, int]]:
        """
        Iterator over the multimodal items for each request
        whose size is between min_num_mm_items and max_num_mm_items.

        Loop over the bucket config and sample a multimodal item.
        Loop until the number of multimodal items sampled is equal to
        request_num_mm_items or limit of multimodal items per prompt
        for all modalities is reached.

        Note:
        - This function operates on a per-request shallow copy of
          `bucket_config` (tuple->float). The original dict passed to
          `sample` is not mutated. If this ever changes, a test
          is implemented and will fail.
        """
        # Get the number of multimodal items to sample
        request_num_mm_items = int(
            self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
        )
        # If request_num_mm_items is 0, yield an empty iterator
        if request_num_mm_items == 0:
            return
        # Initialize modality counters
        modality_counter = {self.map_config_to_modality(k): 0 for k in bucket_config}
        # Copy the bucket config to avoid modifying the original
        bucket_config_copy = bucket_config.copy()
        # Loop over the number of multimodal items to sample
        while sum(modality_counter.values()) < request_num_mm_items:
            # Sample a multimodal item config
            mm_item_config = self._rng.choice(
                list(bucket_config_copy.keys()), p=list(bucket_config_copy.values())
            )
            modality = self.map_config_to_modality(mm_item_config)
            # Check that modality count is less than limit per prompt
            if modality_counter[modality] < limit_mm_per_prompt[modality]:
                modality_counter[modality] += 1
                yield mm_item_config
            else:
                # The modality reached its limit: drop the sampled item and
                # zero out every bucket of that modality.
                for k in bucket_config_copy:
                    if self.map_config_to_modality(k) == modality:
                        bucket_config_copy[k] = 0
                # If all configs are 0, break the loop
                # This should not happen as request_num_mm_items is at most
                # the sum of limit_mm_per_prompt for all modalities
                if all(v == 0 for v in bucket_config_copy.values()):
                    logger.warning(
                        "Exhausted all multimodal items of modality %s", modality
                    )
                    break
                # Renormalize the bucket config (also drops the zeroed entries)
                bucket_config_copy = self.normalize_bucket_config(bucket_config_copy)

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
        output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
        limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
        base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
        num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
        bucket_config: dict[
            tuple[int, int, int], float
        ] = DEFAULT_MM_ITEM_BUCKET_CONFIG,
        enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
        **kwargs,
    ) -> list[SampleRequest]:
        """Generate `num_requests` synthetic text + multimodal requests.

        Text prompts come from the inherited RandomDataset helpers
        (`get_sampling_params`, `get_prefix`, `generate_token_sequence`).
        Each request then receives a list of synthetic items from
        `get_mm_item_iterator`. With `enable_multimodal_chat`, the items are
        folded into the prompt via `apply_multimodal_chat_transformation` and
        `multi_modal_data` is None. Otherwise they are attached as
        `multi_modal_data`.

        The default dict arguments are never mutated:
        `get_mm_item_sampling_params` rebinds them to new dicts.
        """
        # Get the sampling parameters for the dataset
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
        )

        (
            min_num_mm_items,
            max_num_mm_items,
            limit_mm_per_prompt,
            bucket_config,
        ) = self.get_mm_item_sampling_params(
            base_items_per_request,
            num_mm_items_range_ratio,
            limit_mm_per_prompt,
            bucket_config,
        )

        vocab_size = tokenizer.vocab_size
        # Can't use tokenizer.all_special_ids since
        # it returns ONLY ids from special_tokens_map.json
        # We want to exclude placeholder tokens and all
        # tokens that indicate start/end of image as it
        # may break prompt replacement logic.
        prohibited_tokens = list(
            tok_id
            for tok_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        )
        all_tokens = np.arange(vocab_size)
        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
        logger.debug(
            "Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
        )
        # Generate prefix once
        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
        # Add synthetic multimodal items to each request
        mm_requests = []
        token_mismatch_total = 0
        for i in range(num_requests):
            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                tokenizer=tokenizer,
                prefix_token_ids=prefix_token_ids,
                prefix_len=prefix_len,
                vocab_size=vocab_size,
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
                allowed_tokens=allowed_tokens,
            )
            token_mismatch_total += token_mismatch
            # Get multimodal item iterator for a given request
            mm_item_iterator = self.get_mm_item_iterator(
                min_num_mm_items,
                max_num_mm_items,
                bucket_config,
                limit_mm_per_prompt,
            )

            mm_content = cast(
                list[dict[str, Any]],
                [
                    self.generate_mm_item(mm_item_config)
                    for mm_item_config in mm_item_iterator
                ],
            )

            if enable_multimodal_chat:
                # NOTE: For now this option is only provided for completeness
                # given that the serve.py benchmark currently does not use it.
                mm_chat_prompt: Any = self.apply_multimodal_chat_transformation(
                    prompt, mm_content
                )
                sample_request = SampleRequest(
                    prompt=mm_chat_prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    multi_modal_data=None,
                    request_id=request_id_prefix + str(i),
                )
            else:
                sample_request = SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            mm_requests.append(sample_request)

        if token_mismatch_total != 0:
            sign = "more" if token_mismatch_total > 0 else "fewer"
            logger.warning(
                "Across all generated prompts, there were %d %s tokens "
                "than expected after decoding and re-encoding. This is "
                "expected due to the imperfect nature of the sampling "
                "procedure.",
                abs(token_mismatch_total),
                sign,
            )

        return mm_requests
1203

1204

1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------


class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset.  Loads data from a JSON file and generates
    sample requests based on conversation turns.

    Only entries with a ``conversations`` list of at least two turns are kept.
    The ``value`` of the first turn is the prompt and the ``value`` of the
    second is the reference completion. An entry may also carry an ``image``
    or ``video`` field, which is attached as multimodal data.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Load the JSON file, keep multi-turn entries, and optionally shuffle.

        Raises:
            ValueError: If ``dataset_path`` is not set.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Filter entries with at least two conversation turns.
        self.data = [
            entry
            for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        # Seeds the global `random` module so the shuffle is reproducible.
        random.seed(self.random_seed)
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(self.data)

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        lora_path: str | None = None,
        max_loras: int | None = None,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` requests from the loaded conversations.

        Args:
            tokenizer: Used to measure the prompt length and, when
                ``output_len`` is None, the reference completion length.
            num_requests: Target number of requests.
            lora_path: Forwarded to ``get_random_lora_request``.
            max_loras: Forwarded to ``get_random_lora_request``.
            output_len: Fixed expected output length. When None, the
                tokenized length of the reference completion is used.
            enable_multimodal_chat: Apply
                ``apply_multimodal_chat_transformation`` to each prompt.
            request_id_prefix: Prefix prepended to each request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The list of ``SampleRequest`` objects.
        """
        samples: list = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            conversations = entry["conversations"]
            prompt = conversations[0]["value"]

            # NOTE: drawn before the validity check, so skipped entries also
            # consume randomness. Keep this ordering.
            lora_request = self.get_random_lora_request(
                max_loras=max_loras, lora_path=lora_path
            )
            prompt_len = len(tokenizer(prompt).input_ids)
            if output_len is None:
                # Only tokenize the completion when its length is actually used.
                completion = conversations[1]["value"]
                new_output_len = len(tokenizer(completion).input_ids)
            else:
                new_output_len = output_len

            if not is_valid_sequence(
                prompt_len,
                new_output_len,
                skip_min_output_len_check=output_len is not None,
            ):
                continue

            if image_path := entry.get("image"):
                mm_content = process_image(image_path)
            elif video_path := entry.get("video"):
                mm_content = process_video(video_path)
            else:
                mm_content = None
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)

            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                    multi_modal_data=mm_content,
                    # Ids are contiguous over accepted samples only.
                    request_id=request_id_prefix + str(len(samples)),
                )
            )

        self.maybe_oversample_requests(
            samples, num_requests, request_id_prefix, no_oversample
        )
        return samples


1296
1297
class _ValidateDatasetArgs(argparse.Action):
    """Argparse action to validate dataset name and path compatibility."""
1298

1299
1300
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
1301

1302
        # Get current values of both dataset_name and dataset_path
1303
1304
        dataset_name = getattr(namespace, "dataset_name", "random")
        dataset_path = getattr(namespace, "dataset_path", None)
1305

1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
        # Validate the combination
        if dataset_name == "random" and dataset_path is not None:
            parser.error(
                "Cannot use 'random' dataset with --dataset-path. "
                "Please specify the appropriate --dataset-name (e.g., "
                "'sharegpt', 'custom', 'sonnet') for your dataset file: "
                f"{dataset_path}"
            )


1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
def add_dataset_parser(parser: FlexibleArgumentParser):
    """Register all dataset-selection and dataset-specific CLI options.

    Adds the general options (seed, prompt count, dataset name/path, ...)
    directly to ``parser``. Per-dataset options go into dedicated argument
    groups. The random and random-mm groups are populated by
    ``add_random_dataset_base_args`` and
    ``add_random_multimodal_dataset_args``.
    """
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="random",
        action=_ValidateDatasetArgs,
        choices=[
            "sharegpt",
            "burstgpt",
            "sonnet",
            "random",
            "random-mm",
            "random-rerank",
            "hf",
            "custom",
            "custom_mm",
            "prefix_repetition",
            "spec_bench",
        ],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        action=_ValidateDatasetArgs,
        help="Path to the sharegpt/sonnet dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
        "--no-oversample",
        action="store_true",
        help="Do not oversample if the dataset has fewer samples than num-prompts.",
    )
    parser.add_argument(
        "--skip-chat-template",
        action="store_true",
        help="Skip applying chat template to prompt for datasets that support it.",
    )
    parser.add_argument(
        "--enable-multimodal-chat",
        action="store_true",
        help="Enable multimodal chat transformation for datasets that support it.",
    )
    parser.add_argument(
        "--disable-shuffle",
        action="store_true",
        help="Disable shuffling of dataset samples for deterministic ordering.",
    )

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
    custom_group.add_argument(
        "--custom-output-len",
        type=int,
        default=256,
        help="Number of output tokens per request. Unless it is set to -1, the "
        "value overrides potential output length loaded from the dataset. It is "
        "used only for custom dataset.",
    )

    spec_bench_group = parser.add_argument_group("spec bench dataset options")
    spec_bench_group.add_argument(
        "--spec-bench-output-len",
        type=int,
        default=256,
        help="Num of output tokens per request, used only for spec bench dataset.",
    )
    spec_bench_group.add_argument(
        "--spec-bench-category",
        type=str,
        default=None,
        help="Category for spec bench dataset. If None, use all categories.",
    )

    sonnet_group = parser.add_argument_group("sonnet dataset options")
    sonnet_group.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help="Number of input tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help="Number of output tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help="Number of prefix tokens per request, used only for sonnet dataset.",
    )

    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
    sharegpt_group.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.",
    )

    blazedit_group = parser.add_argument_group("blazedit dataset options")
    blazedit_group.add_argument(
        "--blazedit-min-distance",
        type=float,
        default=0.0,
        help="Minimum distance for blazedit dataset. Min: 0, Max: 1.0",
    )
    blazedit_group.add_argument(
        "--blazedit-max-distance",
        type=float,
        default=1.0,
        help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
    )

    random_group = parser.add_argument_group("random dataset options")
    add_random_dataset_base_args(random_group)

    random_mm_group = parser.add_argument_group(
        "random multimodal dataset options extended from random dataset"
    )
    add_random_multimodal_dataset_args(random_mm_group)

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-name",
        type=str,
        default=None,
        help=(
            "Name of the dataset on HuggingFace "
            "(e.g., 'lmarena-ai/VisionArena-Chat'). "
            "Specify this if your dataset-path is a local path."
        ),
    )
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths "
        "from the sampled HF dataset.",
    )

    prefix_repetition_group = parser.add_argument_group(
        "prefix repetition dataset options"
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=256,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=256,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=10,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )


def add_random_dataset_base_args(
    parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) -> None:
    """Add CLI arguments for base random dataset options.

    This function adds arguments needed for:
    - random (random dataset)
    - random-mm (random multimodal dataset)
    - random-rerank (random dataset for reranking)

    Args:
        parser_or_group: Either a parser or an argument group to add arguments to.
    """
    parser_or_group.add_argument(
1526
1527
1528
        "--random-input-len",
        type=int,
        default=1024,
1529
        help="Number of input tokens per request, used only for random sampling.",
1530
    )
1531
    parser_or_group.add_argument(
1532
1533
1534
        "--random-output-len",
        type=int,
        default=128,
1535
        help="Number of output tokens per request, used only for random sampling.",
1536
    )
1537
    parser_or_group.add_argument(
1538
1539
1540
1541
1542
1543
1544
1545
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for random sampling. Must be in the range [0, 1) to define "
        "a symmetric sampling range"
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )
1546
    parser_or_group.add_argument(
1547
1548
1549
        "--random-prefix-len",
        type=int,
        default=0,
1550
1551
1552
1553
1554
1555
1556
1557
        help=(
            "Number of fixed prefix tokens before the random context "
            "in a request. "
            "The total input length is the sum of `random-prefix-len` and "
            "a random "
            "context length sampled from [input_len * (1 - range_ratio), "
            "input_len * (1 + range_ratio)]."
        ),
1558
    )
1559
    parser_or_group.add_argument(
1560
1561
1562
        "--random-batch-size",
        type=int,
        default=1,
1563
        help=("Batch size for random sampling. Only used for embeddings benchmark."),
1564
    )
1565
    parser_or_group.add_argument(
1566
1567
1568
1569
1570
1571
1572
        "--no-reranker",
        action="store_true",
        help=(
            "Whether the model supports reranking natively."
            " Only used for reranker benchmark."
        ),
    )
1573

1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586

def add_random_multimodal_dataset_args(
    parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) -> None:
    """Add CLI arguments for random multimodal dataset options.

    This function adds arguments needed for:
    - random-mm (random multimodal dataset)

    All defaults come from the ``DEFAULT_*`` class attributes of
    ``RandomMultiModalDataset``.

    Args:
        parser_or_group: Either a parser or an argument group to add arguments to.
    """
    parser_or_group.add_argument(
        "--random-mm-base-items-per-request",
        type=int,
        default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
        help=(
            "Base number of multimodal items per request for random-mm. "
            "Actual per-request count is sampled around this base using "
            "--random-mm-num-mm-items-range-ratio."
        ),
    )
    parser_or_group.add_argument(
        "--random-mm-num-mm-items-range-ratio",
        type=float,
        default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
        help=(
            "Range ratio r in [0, 1] for sampling items per request. "
            "We sample uniformly from the closed integer range "
            "[floor(n*(1-r)), ceil(n*(1+r))] "
            "where n is the base items per request. "
            "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
            "to the sum of per-modality limits from "
            "--random-mm-limit-mm-per-prompt. "
            "An error is raised if the computed min exceeds the max."
        ),
    )
    parser_or_group.add_argument(
        "--random-mm-limit-mm-per-prompt",
        type=json.loads,
        default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
        help=(
            "Per-modality hard caps for items attached per request, e.g. "
            '\'{"image": 3, "video": 0}\'. The sampled per-request item '
            "count is clamped to the sum of these limits. When a modality "
            "reaches its cap, its buckets are excluded and probabilities are "
            "renormalized. "
            "OBS.: Only image sampling is supported for now."
        ),
    )

    def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
        """Parse a bucket config into ``{(height, width, num_frames): prob}``.

        Accepts a dict (programmatic use) or a string containing a Python
        dict literal. Python literal syntax is used instead of JSON because
        JSON cannot express tuple keys. Every failure is reported as
        ``ValueError`` so argparse turns it into a usage error rather than a
        traceback.
        """

        def normalize(d: dict) -> dict[tuple[int, int, int], float]:
            out: dict[tuple[int, int, int], float] = {}
            for k, val in d.items():
                key = k
                if isinstance(key, str):
                    # Keys may be given as strings such as "(256, 256, 1)".
                    # An unparsable string is left as-is so the shape check
                    # below reports the user's original key.
                    with suppress(Exception):
                        key = ast.literal_eval(key)
                if not (
                    isinstance(key, tuple)
                    and len(key) == 3
                    and all(isinstance(x, int) for x in key)
                ):
                    raise ValueError(
                        f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
                    )
                out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
            return out

        if isinstance(v, dict):
            return normalize(v)
        if isinstance(v, str):
            # Python literal (supports tuple keys). literal_eval raises
            # SyntaxError on malformed text, which argparse does not catch.
            try:
                parsed = ast.literal_eval(v)
            except (SyntaxError, ValueError, TypeError) as e:
                raise ValueError(
                    f"Could not parse bucket config {v!r} as a Python literal."
                ) from e
            if not isinstance(parsed, dict):
                raise ValueError("Bucket config must parse to a dict.")
            return normalize(parsed)
        raise ValueError("Unsupported value for --random-mm-bucket-config.")

    parser_or_group.add_argument(
        "--random-mm-bucket-config",
        type=_parse_mm_bucket_config,
        default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
        help=(
            "The bucket config is a dictionary mapping a multimodal item "
            "sampling configuration to a probability. "
            "Currently allows for 2 modalities: images and videos. "
            "A bucket key is a tuple of (height, width, num_frames). "
            "The value is the probability of sampling that specific item. "
            "Example: "
            "--random-mm-bucket-config "
            "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
            "First item: images with resolution 256x256 w.p. 0.5. "
            "Second item: images with resolution 720x1280 w.p. 0.4. "
            "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1. "
            "OBS.: If the probabilities do not sum to 1, they are normalized. "
            "OBS bis.: Only image sampling is supported for now."
        ),
    )

1676

1677
def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
    """Load and sample benchmark requests for the dataset selected in ``args``.

    Dispatches on ``args.dataset_name``:
    - ``custom``, ``custom_mm`` and ``sonnet`` are handled directly.
    - ``hf`` picks a HuggingFace-backed dataset class by matching
      ``args.dataset_path`` or ``args.hf_name``.
    - Every other name is looked up in a table of lazily built samplers.

    Side effects: sets ``args.request_id_prefix`` to ``""`` when it is
    missing. For ``hf`` datasets, ``args.hf_split`` / ``args.hf_subset`` may
    be overwritten with the split/subset the matched dataset class uses.

    Args:
        args: Parsed benchmark CLI namespace.
        tokenizer: Tokenizer forwarded to the dataset's ``sample`` method.

    Returns:
        The list of sampled requests.

    Raises:
        ValueError: If the dataset name/path is not supported, if the sonnet
            dataset is used on a non-``openai-chat`` backend without a chat
            template, or if a multimodal dataset is paired with an
            incompatible backend.
    """
    if not hasattr(args, "request_id_prefix"):
        args.request_id_prefix = ""

    if args.dataset_name == "custom":
        dataset = CustomDataset(
            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
        )
        input_requests = dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            skip_chat_template=args.skip_chat_template,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        )

    elif args.dataset_name == "custom_mm":
        dataset = CustomMMDataset(
            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
        )
        input_requests = dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            enable_multimodal_chat=args.enable_multimodal_chat,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        )

    elif args.dataset_name == "sonnet":
        dataset = SonnetDataset(
            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
        )
        # For the "sonnet" dataset, formatting depends on the backend: only
        # non-"openai-chat" backends receive the chat-formatted prompt.
        if args.backend == "openai-chat":
            return_prompt_formatted = False
        else:
            # `default_chat_template` may not exist on newer tokenizers, so
            # probe both attributes defensively instead of raising
            # AttributeError. Use an explicit check (not `assert`) so the
            # validation survives `python -O`.
            has_chat_template = getattr(
                tokenizer, "chat_template", None
            ) or getattr(tokenizer, "default_chat_template", None)
            if not has_chat_template:
                raise ValueError(
                    "Tokenizer/model must have chat template for sonnet dataset."
                )
            return_prompt_formatted = True
        input_requests = dataset.sample(
            num_requests=args.num_prompts,
            input_len=args.sonnet_input_len,
            output_len=args.sonnet_output_len,
            prefix_len=args.sonnet_prefix_len,
            tokenizer=tokenizer,
            return_prompt_formatted=return_prompt_formatted,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        )

    elif args.dataset_name == "hf":
        # all following datasets are implemented from the
        # HuggingFaceDataset base class
        hf_kwargs = {}
        if (
            args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = VisionArenaDataset
            args.hf_split = "train"
            args.hf_subset = None
        elif (
            args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = MMVUDataset
            args.hf_split = "validation"
            args.hf_subset = None
        elif (
            args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = InstructCoderDataset
            args.hf_split = "train"
        elif (
            args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = MTBenchDataset
            args.hf_split = "train"
        elif (
            args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = MultiModalConversationDataset
        elif (
            args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = ConversationDataset
        elif (
            args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = AIMODataset
            args.hf_split = "train"
        elif (
            args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS  # noqa: E501
            or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = NextEditPredictionDataset
            args.hf_split = "train"
        elif (
            args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = ASRDataset
            args.hf_split = "train"
        elif (
            args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in BlazeditDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = BlazeditDataset
            args.hf_split = "train"
            hf_kwargs = {
                "min_distance": args.blazedit_min_distance,
                "max_distance": args.blazedit_max_distance,
            }
        elif (
            args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = MLPerfDataset
            args.hf_split = "train"
        elif (
            args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
        ):
            dataset_class = MMStarDataset
            args.hf_split = "val"
            args.hf_subset = None
        else:
            supported_datasets = {
                dataset_name
                for cls in HuggingFaceDataset.__subclasses__()
                for dataset_name in cls.SUPPORTED_DATASET_PATHS
            }
            raise ValueError(
                f"Unsupported dataset path: {args.dataset_path}. "
                "Huggingface dataset only supports dataset_path"
                f" from one of following: {supported_datasets}. "
                "Please consider contributing if you would "
                "like to add support for additional dataset formats."
            )

        if dataset_class.IS_MULTIMODAL and not (
            args.backend in ("openai-chat", "openai-audio")
            or "embeddings-" in args.backend
        ):
            # Multi-modal benchmarks need an endpoint that accepts
            # multimodal content.
            raise ValueError(
                "Multi-modal content is only supported on 'openai-chat', "
                "'openai-audio' and 'embeddings-*' backends."
            )
        input_requests = dataset_class(
            dataset_path=args.dataset_path,
            dataset_subset=args.hf_subset,
            dataset_split=args.hf_split,
            random_seed=args.seed,
            no_stream=args.no_stream,
            hf_name=args.hf_name,
            disable_shuffle=args.disable_shuffle,
        ).sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.hf_output_len,
            enable_multimodal_chat=args.enable_multimodal_chat,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
            skip_chat_template=args.skip_chat_template,
            **hf_kwargs,
        )

    else:
        # Enforce endpoint compatibility for multimodal datasets.
        if args.dataset_name == "random-mm" and args.backend != "openai-chat":
            raise ValueError(
                "Multi-modal content (images) is only supported on "
                "'openai-chat' backend."
            )

        # For datasets that follow a similar structure, use a mapping of
        # lazily evaluated samplers so only the selected dataset is built.
        dataset_mapping = {
            "spec_bench": lambda: SpecBench(
                dataset_path=args.dataset_path,
                category=args.spec_bench_category,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                num_requests=args.num_prompts,
                tokenizer=tokenizer,
                output_len=args.spec_bench_output_len,
                enable_multimodal_chat=args.enable_multimodal_chat,
                request_id_prefix=args.request_id_prefix,
                no_oversample=args.no_oversample,
            ),
            "sharegpt": lambda: ShareGPTDataset(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                output_len=args.sharegpt_output_len,
                enable_multimodal_chat=args.enable_multimodal_chat,
                request_id_prefix=args.request_id_prefix,
                no_oversample=args.no_oversample,
            ),
            "burstgpt": lambda: BurstGPTDataset(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                request_id_prefix=args.request_id_prefix,
                no_oversample=args.no_oversample,
            ),
            "random": lambda: RandomDataset(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                prefix_len=args.random_prefix_len,
                input_len=args.random_input_len,
                output_len=args.random_output_len,
                range_ratio=args.random_range_ratio,
                request_id_prefix=args.request_id_prefix,
                batchsize=args.random_batch_size,
                no_oversample=args.no_oversample,
            ),
            "random-mm": lambda: RandomMultiModalDataset(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                prefix_len=args.random_prefix_len,
                range_ratio=args.random_range_ratio,
                input_len=args.random_input_len,
                output_len=args.random_output_len,
                base_items_per_request=args.random_mm_base_items_per_request,
                limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
                num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
                bucket_config=args.random_mm_bucket_config,
                enable_multimodal_chat=args.enable_multimodal_chat,
                request_id_prefix=args.request_id_prefix,
                no_oversample=args.no_oversample,
            ),
            "random-rerank": lambda: RandomDatasetForReranking(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                input_len=args.random_input_len,
                range_ratio=args.random_range_ratio,
                request_id_prefix=args.request_id_prefix,
                batchsize=args.random_batch_size,
                is_reranker=not args.no_reranker,
            ),
            "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
                random_seed=args.seed,
                dataset_path=args.dataset_path,
                disable_shuffle=args.disable_shuffle,
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                prefix_len=args.prefix_repetition_prefix_len,
                suffix_len=args.prefix_repetition_suffix_len,
                num_prefixes=args.prefix_repetition_num_prefixes,
                output_len=args.prefix_repetition_output_len,
                request_id_prefix=args.request_id_prefix,
                no_oversample=args.no_oversample,
            ),
        }

        try:
            sample_fn = dataset_mapping[args.dataset_name]
        except KeyError as err:
            raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
        # Invoke outside the try: a KeyError raised while loading or sampling
        # the dataset must not be misreported as an unknown dataset name.
        input_requests = sample_fn()

    return input_requests


1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------


class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset. Loads data from a JSONL file in which each
    line becomes one sample request. E.g.,
    ```
    {"prompt": "What is the capital of India?", "output_tokens": 10}
    {"prompt": "What is the capital of Iran?", "output_tokens": 1520}
    {"prompt": "What is the capital of China?", "output_tokens": 819}
    ```

    Note that 'output_tokens' column is optional and has to be provided only if
    'custom-output-len' argument is None or -1.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Load ``self.dataset_path`` into ``self.data`` as a list of row dicts.

        ``self.data`` is the standardized format that ``sample()`` consumes,
        e.g. ``[{"prompt": "What is the capital of India?"}, ...]``. All
        columns of each row are kept. The list is shuffled with
        ``self.random_seed`` (seeding the global ``random`` module) unless
        ``disable_shuffle`` is set.

        Raises:
            ValueError: If no path is given or the file lacks a 'prompt' column.
            NotImplementedError: If the file is not ``.jsonl``.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # Only JSONL is supported. Other file types would need their own
        # conversion into the list-of-dicts format described above.
        if not self.dataset_path.endswith(".jsonl"):
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset."
            )

        jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)
        if "prompt" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'prompt' column.")

        # One dict per DataFrame row, in file order.
        self.data = jsonl_data.to_dict(orient="records")

        random.seed(self.random_seed)
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(self.data)

    @staticmethod
    def _item_output_len(item: dict) -> int:
        """Return a row's ``output_tokens`` value as an int.

        Raises:
            ValueError: If the field is missing/empty or not an integer.
        """
        raw = item.get("output_tokens")
        # When only some JSONL lines define "output_tokens", pandas fills the
        # other rows with NaN, so NaN must be treated like a missing field.
        if raw is None or (isinstance(raw, float) and math.isnan(raw)):
            raise ValueError(
                "If no output length is provided the "
                "custom dataset must contain an 'output_tokens' field."
            )
        try:
            return int(raw)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Invalid value for 'output_tokens' in custom dataset: "
                f"'{raw}'. Must be an integer."
            ) from e

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        lora_path: str | None = None,
        max_loras: int | None = None,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build one request per row of ``self.data``, in list order.

        Args:
            tokenizer: Used to apply the chat template and count prompt tokens.
            num_requests: Number of requests wanted; ``<= 0`` means all rows.
            lora_path: Unused by this dataset.
            max_loras: Unused by this dataset.
            output_len: Expected output length for every request. ``None`` or
                ``-1`` means each row's ``output_tokens`` field is used instead.
            enable_multimodal_chat: Unused by this dataset.
            skip_chat_template: If True, send the raw prompt without applying
                the tokenizer's chat template.
            request_id_prefix: Prefix prepended to each row index to form the
                request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled requests.

        Raises:
            ValueError: If ``output_tokens`` is needed but missing or invalid.
        """
        # Record how many rows are available before sampling.
        self.num_available_samples = len(self.data)
        if num_requests <= 0:
            num_requests = self.num_available_samples
            logger.info(
                "num_requests is set to 0 or negative, "
                "so using all available samples: %d",
                num_requests,
            )

        use_row_output_len = output_len is None or output_len == -1

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]

            new_output_len = (
                self._item_output_len(item) if use_row_output_len else output_len
            )

            # apply template
            if not skip_chat_template:
                prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )

        return sampled_requests


2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
class CustomMMDataset(CustomDataset):
    """
    Implements the Custom MultiModal dataset. Loads data from a JSONL file in
    which each line becomes one sample request. E.g.,
    ```
    {
        "prompt": "How many red blocks in the given images?",
        "image_files": ["path/to/image1.png", "path/to/image2.png"],
    }
    {
        "prompt": "Which country has the most pokemons based on the given graphs?",
        "image_files": ["path/to/image.png"],
    }
    ```

    NOTE: Only the first image file in "image_files" is used for each sample request.
    A single path string is also accepted for "image_files".

    This is used to benchmark multimodal LLMs on arbitrary datasets.
    """

    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build one multimodal request per row of ``self.data``.

        Args:
            tokenizer: Used to count prompt tokens (text only).
            num_requests: Number of requests wanted; ``<= 0`` means all rows.
            output_len: Passed through unchanged as each request's expected
                output length.
            enable_multimodal_chat: If True, the prompt is transformed with
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prefix prepended to each row index to form the
                request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled requests.

        Raises:
            ValueError: If a row has no usable ``image_files`` entry.
        """
        # Record how many rows are available before sampling.
        self.num_available_samples = len(self.data)
        if num_requests <= 0:
            num_requests = self.num_available_samples
            logger.info(
                "num_requests is set to 0 or negative, "
                "so using all available samples: %d",
                num_requests,
            )

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]

            prompt_len = len(tokenizer(prompt).input_ids)
            images = item.get("image_files")
            # A bare path string would otherwise be indexed per character.
            if isinstance(images, str):
                images = [images]
            # Rows lacking the field come back from pandas as NaN, not a list.
            if not isinstance(images, (list, tuple)) or len(images) == 0:
                raise ValueError(
                    f"Sample {i} must have a non-empty 'image_files' list."
                )
            if len(images) > 1:
                logger.warning(
                    "Multiple image files found for sample %d. "
                    "Only the first image will be used.",
                    i,
                )
            mm_content = process_image(images[0])
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )

        return sampled_requests


2180
2181
2182
2183
2184
2185
2186
2187
# -----------------------------------------------------------------------------
# Spec Bench Dataset Implementation
# -----------------------------------------------------------------------------


class SpecBench(CustomDataset):
    """
    Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
    Download the dataset using:
    wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl

    The first turn (``turns[0]``) of each row becomes one prompt. Pass
    ``category=...`` to keep only rows whose ``category`` matches.
    """  # noqa: E501

    def __init__(self, **kwargs) -> None:
        # Must be set before super().__init__(): CustomDataset.__init__ calls
        # self.load_data() (this class's override), which reads self.category.
        # No second load_data() call is needed here.
        self.category = kwargs.pop("category", None)
        super().__init__(**kwargs)

    def load_data(self) -> None:
        """Load first-turn prompts from the SpecBench JSONL into ``self.data``.

        Raises:
            ValueError: If no path is given, the file lacks a 'turns' column,
                or a category filter is requested without a 'category' column.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)

        # check if the JSONL file has a 'turns' column
        if "turns" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'turns' column.")

        # sample only from a specific category if specified
        if self.category:
            if "category" not in jsonl_data.columns:
                raise ValueError(
                    "JSONL file must contain a 'category' column "
                    "to filter by category."
                )
            jsonl_data = jsonl_data[jsonl_data["category"] == self.category]

        self.data = [{"prompt": turns[0]} for turns in jsonl_data["turns"]]

        random.seed(self.random_seed)
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(self.data)

    def sample(self, **kwargs) -> list:
        # leverage CustomDataset sample
        return super().sample(**kwargs)
2223
2224


2225
2226
2227
2228
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------

2229

2230
2231
2232
@deprecated(
    "SonnetDataset is deprecated and will be removed in a future version.",
)
class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset.  Loads poem lines from a
    text file and generates sample requests.  Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Read the poem file into ``self.data`` as a list of raw lines.

        Raises:
            ValueError: If ``dataset_path`` is empty or unset.
        """
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = f.readlines()

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build ``num_requests`` prompts from a shared prefix plus random lines.

        Each prompt is an instruction header, the first ``num_prefix_lines``
        poem lines (identical for every request) and randomly drawn poem
        lines. Candidates whose chat-formatted token length exceeds
        ``input_len`` are rejected and redrawn.

        Args:
            tokenizer: Tokenizer used for length estimates and chat templating.
            num_requests: Number of requests to produce.
            prefix_len: Target token length of the shared prefix.
            input_len: Maximum token length of a formatted prompt.
            output_len: Expected output length stored on each request.
            return_prompt_formatted: Return the chat-formatted prompt instead
                of the raw prompt text.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Accepted for interface compatibility; unused here.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            ValueError: If the dataset has no lines, if ``input_len`` does not
                exceed the formatted header length, or if the prompt built
                without any random lines already exceeds ``input_len``.
        """
        if not self.data:
            raise ValueError("Sonnet dataset file contains no lines.")

        # Calculate average token length for a poem line.
        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
        avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)

        # Build the base prompt.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_msg = [{"role": "user", "content": base_prompt}]
        base_fmt = tokenizer.apply_chat_template(
            base_msg, add_generation_prompt=True, tokenize=False
        )
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset})."
            )

        # Determine how many poem lines to use.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]
        # May be <= 0 when the prefix already uses the whole budget; in that
        # case random.choices draws nothing.
        num_extra_lines = num_input_lines - num_prefix_lines

        samples = []
        while len(samples) < num_requests:
            extra_lines = random.choices(self.data, k=num_extra_lines)
            prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
            msg = [{"role": "user", "content": prompt}]
            prompt_formatted = tokenizer.apply_chat_template(
                msg, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)
            if prompt_len > input_len:
                if num_extra_lines <= 0:
                    # No random lines are drawn, so every retry would rebuild
                    # this exact prompt; fail instead of looping forever.
                    raise ValueError(
                        f"The prompt built from {num_prefix_lines} prefix "
                        f"lines is {prompt_len} tokens, which exceeds "
                        f"'input_len' ({input_len}); increase 'input_len' "
                        f"or decrease 'prefix_len'."
                    )
                continue
            samples.append(
                SampleRequest(
                    prompt=prompt_formatted if return_prompt_formatted else prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    # Ids are sequential over accepted samples only.
                    request_id=request_id_prefix + str(len(samples)),
                )
            )
        return samples


# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------


class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset.  Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    # CSV column names holding the per-request token counts.
    REQUEST_TOKENS_COLUMN = "Request tokens"
    RESPONSE_TOKENS_COLUMN = "Response tokens"

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self):
        """Read the CSV and keep only successful GPT-4 requests.

        Raises:
            ValueError: If ``dataset_path`` is not set.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        df = pd.read_csv(self.dataset_path)
        # Filter to keep only GPT-4 rows.
        gpt4_df = df[df["Model"] == "GPT-4"]
        # Remove failed requests (where Response tokens is 0 or less).
        gpt4_df = gpt4_df[gpt4_df[self.RESPONSE_TOKENS_COLUMN] > 0]
        # Keep every eligible row; per-call sampling happens in
        # ``_sample_loaded_data``.
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        """Draw ``num_requests`` rows, sampling with replacement only if needed.

        Returns:
            The sampled rows as a list of lists, in DataFrame column order.

        Raises:
            ValueError: If rows are requested but no eligible rows were loaded.
        """
        if num_requests > 0 and len(self.data) == 0:
            raise ValueError(
                "BurstGPT dataset has no GPT-4 rows with positive response "
                "tokens to sample from."
            )
        if num_requests <= len(self.data):
            data = self.data.sample(n=num_requests, random_state=self.random_seed)
        else:
            data = self.data.sample(
                n=num_requests,
                random_state=self.random_seed,
                replace=True,
            )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        max_loras: int | None = None,
        lora_path: str | None = None,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Create synthetic requests whose lengths follow sampled trace rows.

        Args:
            tokenizer: Tokenizer used to decode synthetic token ids.
            num_requests: Number of requests to create.
            max_loras: Forwarded to ``get_random_lora_request``.
            lora_path: Forwarded to ``get_random_lora_request``.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Accepted for interface compatibility; unused here.

        Returns:
            A list of ``SampleRequest`` objects.
        """
        # Resolve the token-count columns by name instead of by fixed
        # positional index, so a reordered CSV still reads the right values.
        request_col = self.data.columns.get_loc(self.REQUEST_TOKENS_COLUMN)
        response_col = self.data.columns.get_loc(self.RESPONSE_TOKENS_COLUMN)
        vocab_size = tokenizer.vocab_size

        samples = []
        rows = self._sample_loaded_data(num_requests=num_requests)
        for i, row in enumerate(rows):
            input_len = int(row[request_col])
            output_len = int(row[response_col])
            lora_req = self.get_random_lora_request(
                max_loras=max_loras, lora_path=lora_path
            )
            # Generate a synthetic prompt: a list of token IDs computed as
            # (i + j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            prompt = tokenizer.decode(token_ids)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                    request_id=request_id_prefix + str(i),
                )
            )
        return samples


# -----------------------------------------------------------------------------
# HuggingFace Dataset Base Implementation
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Shared base for benchmark datasets fetched with ``datasets.load_dataset``.

    Subclasses declare the hub paths they accept in
    ``SUPPORTED_DATASET_PATHS`` (a set, or a dict mapping path to a parser).
    """

    SUPPORTED_DATASET_PATHS: set[str] | dict[str, Callable] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        no_stream: bool = False,
        dataset_subset: str | None = None,
        hf_name: str | None = None,
        **kwargs,
    ) -> None:
        """Store the HuggingFace loading options, then load the data.

        Args:
            dataset_path: Path or hub id passed to ``load_dataset``.
            dataset_split: Split name passed to ``load_dataset``.
            no_stream: When True, load eagerly instead of streaming.
            dataset_subset: Optional config name passed as ``name``.
            hf_name: Logical dataset name; defaults to ``dataset_path``.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(dataset_path=dataset_path, **kwargs)

        self.hf_name = hf_name or dataset_path
        self.load_stream = not no_stream
        self.dataset_subset = dataset_subset
        self.dataset_split = dataset_split
        self.load_data()

    def load_data(self) -> None:
        """Load data from HuggingFace datasets."""
        load_options = {
            "name": self.dataset_subset,
            "split": self.dataset_split,
            "streaming": self.load_stream,
        }
        self.data = load_dataset(self.dataset_path, **load_options)
        if getattr(self, "disable_shuffle", False):
            return
        # Deterministic shuffle driven by the benchmark seed.
        self.data = self.data.shuffle(seed=self.random_seed)
2428
2429
2430
2431
2432
2433
2434
2435


# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------


class ConversationDataset(HuggingFaceDataset):
    """Dataset for text-only conversation data."""

    SUPPORTED_DATASET_PATHS = {
        "Aeala/ShareGPT_Vicuna_unfiltered",
    }
    IS_MULTIMODAL = False

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Sample prompt/completion pairs from the first two conversation turns.

        Args:
            tokenizer: Tokenizer used to measure the prompt length and, when
                ``output_len`` is None, the completion length.
            num_requests: Target number of requests.
            output_len: Fixed expected output length. When None, the
                tokenized completion length is used and pairs that fail
                ``is_valid_sequence`` (or have an empty completion) are
                skipped.
            enable_multimodal_chat: Wrap the prompt with
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled ``SampleRequest`` list, after it has been passed
            through ``maybe_oversample_requests``.

        Raises:
            ValueError: If a fixed ``output_len`` is not a positive int.
        """
        dynamic_output = output_len is None
        # Validate a fixed length once, up front, instead of with an assert
        # inside the loop (asserts are stripped under ``python -O``).
        if not dynamic_output and (
            not isinstance(output_len, int) or output_len <= 0
        ):
            raise ValueError(
                f"output_len must be a positive integer, got {output_len!r}."
            )

        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]
            prompt_len = len(tokenizer(prompt).input_ids)

            if dynamic_output:
                # Only tokenize the completion when its length is used.
                expected_len = len(tokenizer(completion).input_ids)
                # Skip (rather than crash on) empty or invalid pairs.
                if expected_len <= 0 or not is_valid_sequence(
                    prompt_len, expected_len
                ):
                    continue
            else:
                expected_len = output_len

            mm_content = process_image(item["image"]) if "image" in item else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_len,
                    multi_modal_data=mm_content,
                    # Sequential over accepted items, not over the raw stream.
                    request_id=request_id_prefix + str(len(sampled_requests)),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


class MultiModalConversationDataset(HuggingFaceDataset):
    """Dataset for multimodal conversation data."""

    SUPPORTED_DATASET_PATHS = {
        "lmms-lab/LLaVA-OneVision-Data",
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Sample prompt/completion pairs (with optional image) from conversations.

        Args:
            tokenizer: Tokenizer used to measure the prompt length and, when
                ``output_len`` is None, the completion length.
            num_requests: Target number of requests.
            output_len: Fixed expected output length. When None, the
                tokenized completion length is used and pairs that fail
                ``is_valid_sequence`` (or have an empty completion) are
                skipped.
            enable_multimodal_chat: Wrap the prompt with
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled ``SampleRequest`` list, after it has been passed
            through ``maybe_oversample_requests``.

        Raises:
            ValueError: If a fixed ``output_len`` is not a positive int.
        """
        dynamic_output = output_len is None
        # Validate a fixed length once, up front, instead of with an assert
        # inside the loop (asserts are stripped under ``python -O``).
        if not dynamic_output and (
            not isinstance(output_len, int) or output_len <= 0
        ):
            raise ValueError(
                f"output_len must be a positive integer, got {output_len!r}."
            )

        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]
            prompt_len = len(tokenizer(prompt).input_ids)

            if dynamic_output:
                # Only tokenize the completion when its length is used.
                expected_len = len(tokenizer(completion).input_ids)
                # Skip (rather than crash on) empty or invalid pairs.
                if expected_len <= 0 or not is_valid_sequence(
                    prompt_len, expected_len
                ):
                    continue
            else:
                expected_len = output_len

            mm_content = process_image(item["image"]) if "image" in item else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_len,
                    multi_modal_data=mm_content,
                    # Sequential over accepted items, not over the raw stream.
                    request_id=request_id_prefix + str(len(sampled_requests)),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------


class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build image+text requests from the first image and prompt of each item.

        Args:
            tokenizer: Tokenizer used to measure the text prompt length.
            num_requests: Target number of requests.
            output_len: Expected output length; defaults to
                ``DEFAULT_OUTPUT_LEN``.
            enable_multimodal_chat: Wrap the prompt with
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prefix prepended to each request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled ``SampleRequest`` list.

        Raises:
            ValueError: If ``self.hf_name`` has no registered prompt parser.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The parser depends only on the dataset name, so resolve it once
        # (and fail fast) instead of on every item.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.hf_name}")

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


2613
2614
2615
2616
2617
2618
2619
2620
class MMVUDataset(HuggingFaceDataset):
    """
    MMVU Dataset.
    https://huggingface.co/datasets/yale-nlp/MMVU
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {
        "yale-nlp/MMVU": lambda x: x["question"]
        + " "
        + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
    }
    # NOTE(review): this class emits video content but, unlike
    # VisionArenaDataset, does not set IS_MULTIMODAL = True, so it inherits
    # the parent's value. Confirm whether that is intended.

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build video+question requests, one per dataset item.

        Args:
            tokenizer: Tokenizer used to measure the text prompt length.
            num_requests: Target number of requests.
            output_len: Expected output length; defaults to
                ``DEFAULT_OUTPUT_LEN``.
            enable_multimodal_chat: Wrap the prompt with
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prefix prepended to each request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled ``SampleRequest`` list.

        Raises:
            ValueError: If ``self.hf_name`` has no registered prompt parser.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The parser depends only on the dataset name, so resolve it once
        # (and fail fast) instead of on every item.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.hf_name}")

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            mm_content = process_video(item["video"])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is a dataset designed for general code editing.  It consists
    of 114,239 instruction-input-output triplets and covers multiple distinct
    code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Turn the first ``num_requests`` code-edit items into requests.

        Unless ``skip_chat_template`` is set, each prompt is wrapped as a
        single user message via the tokenizer's chat template before its
        token length is measured. ``output_len`` defaults to
        ``DEFAULT_OUTPUT_LEN``.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        requests = []
        for index, text in enumerate(self.sample_prompts(n=num_requests)):
            if not skip_chat_template:
                chat = [{"role": "user", "content": text}]
                text = tokenizer.apply_chat_template(
                    chat,
                    add_generation_prompt=True,
                    tokenize=False,
                )
            token_count = len(tokenizer(text).input_ids)
            requests.append(
                SampleRequest(
                    prompt=text,
                    prompt_len=token_count,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(index),
                )
            )
        self.maybe_oversample_requests(
            requests, num_requests, request_id_prefix, no_oversample
        )
        return requests

    def sample_prompts(self, n: int) -> Iterator[str]:
        """Yield a code-editing prompt for each of the first ``n`` items."""
        instruction_suffix = " Just output the code, do not include any explanation."
        for record in self.data.take(n):
            yield f"{record['input']}\n\n{record['instruction']}{instruction_suffix}"

2731

2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    We create a single turn dataset for MT-Bench.
    This is similar to Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build single-turn requests from the first turn of each MT-Bench item.

        Unless ``skip_chat_template`` is set, the turn is wrapped as a user
        message with the tokenizer's chat template. ``output_len`` defaults to
        ``DEFAULT_OUTPUT_LEN``.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        requests = []
        for idx, record in enumerate(self.data):
            if len(requests) >= num_requests:
                break
            text = record["turns"][0]
            if not skip_chat_template:
                text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": text}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
            token_count = len(tokenizer(text).input_ids)
            requests.append(
                SampleRequest(
                    prompt=text,
                    prompt_len=token_count,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(idx),
                )
            )
        self.maybe_oversample_requests(
            requests, num_requests, request_id_prefix, no_oversample
        )
        return requests


2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
# -----------------------------------------------------------------------------
# Blazedit Dataset Implementation
# -----------------------------------------------------------------------------


class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char
    """  # noqa: E501

    # 5k char version will have output as ~5k chars
    # 10k char version will have output as ~10k chars
    # Assuming 3 char per token, 10k chars will be 3333 tokens
    # We set default to 4000 to be safe
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        """Build code-edit requests whose normalized edit distance is in range.

        Args:
            tokenizer: Tokenizer used for chat templating and length counting.
            num_requests: Target number of requests.
            output_len: Expected output length; defaults to
                ``DEFAULT_OUTPUT_LEN``.
            skip_chat_template: Use the raw prompt instead of the chat template.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.
            min_distance: Inclusive lower bound on ``norm_distance``.
            max_distance: Inclusive upper bound on ``norm_distance``.

        Returns:
            The sampled ``SampleRequest`` list.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            code = item["code"]
            change_request = item["change_request"]
            norm_distance = item["norm_distance"]

            # compare the levenshtein distance normalized by code length
            if norm_distance < min_distance or norm_distance > max_distance:
                continue

            # template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
            prompt = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # apply template
            if not skip_chat_template:
                prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )

            prompt_len = len(tokenizer(prompt).input_ids)

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    # Items can be skipped by the distance filter, so number
                    # accepted requests contiguously (as the other filtering
                    # samplers do) rather than using the raw item index.
                    # NOTE(review): non-contiguous ids risk clashing with ids
                    # appended by maybe_oversample_requests, assuming it numbers
                    # new requests from len(requests).
                    request_id=request_id_prefix + str(len(sampled_requests)),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


2883
2884
2885
2886
2887
2888
2889
2890
2891
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------


class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing an AIMO dataset with reasoning questions.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime",
        "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT",
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Sample problem/solution pairs as text requests.

        Args:
            tokenizer: Tokenizer used to measure the problem length and, when
                ``output_len`` is None, the solution length.
            num_requests: Target number of requests.
            output_len: Fixed expected output length. When None, the
                tokenized solution length is used and pairs that fail
                ``is_valid_sequence`` (max prompt 2048, max total 32000) or
                have an empty solution are skipped.
            request_id_prefix: Prefix prepended to each sequential request id.
            no_oversample: Forwarded to ``maybe_oversample_requests``.

        Returns:
            The sampled ``SampleRequest`` list.

        Raises:
            ValueError: If a fixed ``output_len`` is not a positive int.
        """
        dynamic_output = output_len is None
        # Validate a fixed length once, up front, instead of with an assert
        # inside the loop (asserts are stripped under ``python -O``).
        if not dynamic_output and (
            not isinstance(output_len, int) or output_len <= 0
        ):
            raise ValueError(
                f"output_len must be a positive integer, got {output_len!r}."
            )

        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt, completion = item["problem"], item["solution"]
            prompt_len = len(tokenizer(prompt).input_ids)

            if dynamic_output:
                # Only tokenize the solution when its length is used.
                expected_len = len(tokenizer(completion).input_ids)
                # Skip (rather than crash on) empty or invalid pairs.
                if expected_len <= 0 or not is_valid_sequence(
                    prompt_len,
                    expected_len,
                    max_prompt_len=2048,
                    max_total_len=32000,
                ):
                    continue
            else:
                expected_len = output_len

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_len,
                    multi_modal_data=None,
                    # Sequential over accepted items, not over the raw stream.
                    request_id=request_id_prefix + str(len(sampled_requests)),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960


# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------


zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### User Edits:

{}

### User Excerpt:

{}

### Response:

2961
"""  # noqa: E501
2962
2963
2964


def _format_zeta_prompt(
2965
2966
    sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
2967
    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
2968
2969
2970

    This function formats examples from the NEP dataset
    into prompts and expected outputs. It could be
2971
    further extended to support more NEP datasets.
2972

2973
    Args:
2974
        sample: The dataset sample containing events,
2975
            inputs, and outputs.
2976
2977
        original_start_marker: The marker indicating the
            start of the editable region. Defaults to
2978
            "<|editable_region_start|>".
2979

2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
    Returns:
        A dictionary with the formatted prompts and expected outputs.
    """
    events = sample["events"]
    input = sample["input"]
    output = sample["output"]
    prompt = zeta_prompt.format(events, input)

    # following the original implementation, extract the focused region
    # from the raw output
    output_start_index = output.find(original_start_marker)
    output_focused_region = output[output_start_index:]
    expected_output = output_focused_region

    return {"prompt": prompt, "expected_output": expected_output}


class NextEditPredictionDataset(HuggingFaceDataset):
    """
    Dataset class for processing a Next Edit Prediction dataset.

    Each supported HF dataset path maps (via ``MAPPING_PROMPT_FUNCS``) to a
    formatting function that turns a raw record into a dict with "prompt"
    and "expected_output" entries.
    """

    SUPPORTED_DATASET_PATHS = {
        "zed-industries/zeta",
    }
    MAPPING_PROMPT_FUNCS = {
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build up to ``num_requests`` requests from ``self.data``.

        Args:
            tokenizer: Used to count tokens of the prompt and of the
                expected output.
            num_requests: Number of requests to produce.
            request_id_prefix: Prefix prepended to each record's index to
                form its ``request_id``.
            no_oversample: Forwarded to ``maybe_oversample_requests``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of ``SampleRequest`` objects.

        Raises:
            ValueError: If ``self.hf_name`` has no registered formatter.
        """
        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name)
        if formatting_prompt_func is None:
            raise ValueError(f"Unsupported dataset path: {self.hf_name}")
        samples: list[SampleRequest] = []
        for i, raw_sample in enumerate(self.data):
            # Check before appending (like the other datasets) so that
            # num_requests <= 0 yields no requests instead of one.
            if len(samples) >= num_requests:
                break
            formatted = formatting_prompt_func(raw_sample)
            samples.append(
                SampleRequest(
                    prompt=formatted["prompt"],
                    prompt_len=len(tokenizer(formatted["prompt"]).input_ids),
                    expected_output_len=len(
                        tokenizer(formatted["expected_output"]).input_ids
                    ),
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            samples, num_requests, request_id_prefix, no_oversample
        )
        return samples


# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing an ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr,  ...        |
    | LibriSpeech    | Audiobook                              | Narrated                 | clean, other, all           |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+

    """  # noqa: E501

    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }

    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = (
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    )
    # Whisper max supported input duration, in seconds.
    MAX_AUDIO_DURATION_S = 30
    # When True, clips longer than MAX_AUDIO_DURATION_S are skipped.
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        """Build transcription requests from the dataset's audio clips.

        Every request uses the same text prompt (``TRANSCRIPTION_PREAMBLE``)
        and carries the clip as ``(array, sampling_rate)`` in its
        multimodal data.

        Args:
            tokenizer: Used to count the preamble's tokens.
            num_requests: Number of requests to produce.
            output_len: Expected output length for every request; defaults
                to ``DEFAULT_OUTPUT_LEN`` when ``None``.
            request_id_prefix: Prefix prepended to each request's index to
                form its ``request_id``.
            no_oversample: Forwarded to ``maybe_oversample_requests``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of ``SampleRequest`` objects.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
        prompt_len = len(tokenizer(prompt).input_ids)
        sampled_requests = []
        skipped = 0
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            audio = item["audio"]
            y, sr = audio["array"], audio["sampling_rate"]
            duration_s = librosa.get_duration(y=y, sr=sr)
            if self.skip_long_audios and duration_s > self.MAX_AUDIO_DURATION_S:
                skipped += 1
                continue

            mm_content = {"audio": (y, sr)}
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    # Index among accepted clips (skipped clips take no id).
                    request_id=request_id_prefix + str(len(sampled_requests)),
                )
            )
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
                " their length being greater than"
                " what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# MLPerf Dataset Implementation
# -----------------------------------------------------------------------------


class MLPerfDataset(HuggingFaceDataset):
    """
    MLPerf Inference Dataset.

    Dataset on HF:
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data

    Each record contains:
      - "system_prompt": system role instruction.
      - "question": user question.
      - "output": reference answer.

    We combine the system prompt and question into a chat-formatted prompt
    (using the tokenizer's chat template). Unless an explicit output length
    is given, the expected output length is the tokenized length of the
    provided reference answer.
    """

    SUPPORTED_DATASET_PATHS = {
        "mgoin/mlperf-inference-llama2-data",
        "mgoin/mlperf-inference-llama3.1-data",
    }

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build chat-formatted requests from the MLPerf records.

        Args:
            tokenizer: Applies the chat template and counts tokens.
            num_requests: Number of requests to produce.
            output_len: Fixed expected output length. When ``None``, each
                request uses the token count of its reference answer
                (tokenized without special tokens).
            request_id_prefix: Prefix prepended to each request's index to
                form its ``request_id``.
            no_oversample: Forwarded to ``maybe_oversample_requests``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of ``SampleRequest`` objects. Records rejected by
            ``is_valid_sequence`` are skipped.
        """
        # Derive the output length from the reference completion only when
        # the caller did not fix it explicitly.
        dynamic_output = output_len is None
        sampled_requests: list[SampleRequest] = []
        ind = 0

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break

            system_prompt = item["system_prompt"]
            question = item["question"]
            reference_answer = item["output"]

            # Build the chat-style prompt with the tokenizer's chat template.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ]
            prompt_formatted = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)

            # Determine output length from reference answer tokens.
            ref_out_len = len(
                tokenizer(reference_answer, add_special_tokens=False).input_ids
            )
            expected_output_len = ref_out_len if dynamic_output else output_len

            # Validate sequence lengths.
            if not is_valid_sequence(prompt_len, expected_output_len):
                continue

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt_formatted,
                    prompt_len=prompt_len,
                    expected_output_len=expected_output_len,
                    request_id=request_id_prefix + str(ind),
                )
            )
            # Only accepted requests advance the id counter.
            ind += 1

        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# Prefix Repetition Dataset Implementation
# -----------------------------------------------------------------------------


class PrefixRepetitionRandomDataset(BenchmarkDataset):
    """Synthetic random-token prompts built from a few repeated prefixes.

    ``num_prefixes`` random prefixes of ``prefix_len`` tokens are generated.
    Each prefix is combined with ``num_requests // num_prefixes`` fresh
    random suffixes of ``suffix_len`` tokens.
    """

    # Default values copied from benchmark_serving.py for the repeated prefix
    # dataset.
    DEFAULT_PREFIX_LEN = 256
    DEFAULT_SUFFIX_LEN = 256
    DEFAULT_NUM_PREFIXES = 10
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Seeds the global `random` and NumPy generators used by sample().
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        suffix_len: int = DEFAULT_SUFFIX_LEN,
        num_prefixes: int = DEFAULT_NUM_PREFIXES,
        output_len: int = DEFAULT_OUTPUT_LEN,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Generate prompts that share repeated random prefixes.

        Args:
            tokenizer: Supplies ``vocab_size`` and decodes token ids to text.
            num_requests: Requested number of prompts. Only
                ``(num_requests // num_prefixes) * num_prefixes`` are
                produced; a warning is logged if that is fewer.
            prefix_len: Target token length of each shared prefix.
            suffix_len: Target token length of each per-request suffix.
            num_prefixes: Number of distinct prefixes.
            output_len: Expected output length for every request.
            request_id_prefix: Prefix prepended to each request's index to
                form its ``request_id``.
            no_oversample: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The generated requests, shuffled unless ``disable_shuffle`` is set.

        Raises:
            ValueError: If ``num_prefixes`` is not positive or exceeds
                ``num_requests``.
        """
        if num_prefixes <= 0:
            raise ValueError(f"num_prefixes ({num_prefixes}) must be positive")
        vocab_size = tokenizer.vocab_size
        prompts_per_prefix = num_requests // num_prefixes
        if prompts_per_prefix == 0:
            raise ValueError(
                f"num_requests ({num_requests}) must be greater than or equal "
                f"to num_prefixes ({num_prefixes})"
            )
        num_generated = prompts_per_prefix * num_prefixes
        if num_generated != num_requests:
            logger.warning(
                "num_requests (%d) is not divisible by num_prefixes (%d); "
                "only %d requests will be generated.",
                num_requests,
                num_prefixes,
                num_generated,
            )

        def _generate_exact_length_tokens(
            target_length: int,
        ) -> tuple[list[int], int]:
            """Draw random token ids and adjust them toward target_length.

            Returns the adjusted token ids together with the mismatch reported
            by ``gen_prompt_decode_to_target_len`` (a positive value means more
            tokens than targeted after decode/re-encode).
            """
            tokens = np.random.randint(0, vocab_size, size=target_length).tolist()
            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(
                tokenizer=tokenizer,
                token_sequence=tokens,
                target_token_len=target_length,
                add_special_tokens=False,
            )
            return adjusted_tokens, token_mismatch

        requests: list[SampleRequest] = []
        token_mismatch_total = 0
        for _ in range(num_prefixes):
            prefix_tokens, prefix_mismatch = _generate_exact_length_tokens(prefix_len)
            token_mismatch_total += prefix_mismatch

            for _ in range(prompts_per_prefix):
                suffix_tokens, suffix_mismatch = _generate_exact_length_tokens(
                    suffix_len
                )
                token_mismatch_total += suffix_mismatch
                combined_tokens = prefix_tokens + suffix_tokens
                prompt = tokenizer.decode(combined_tokens)
                # prompt_len is the number of token ids, not a re-encoding of
                # the decoded text.
                prompt_len = len(combined_tokens)
                requests.append(
                    SampleRequest(
                        prompt=prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                        request_id=request_id_prefix + str(len(requests)),
                    )
                )

        if token_mismatch_total != 0:
            sign = "more" if token_mismatch_total > 0 else "fewer"
            logger.warning(
                "Across all generated prompts, there were %d %s tokens "
                "than expected after decoding and re-encoding. This is "
                "expected due to the imperfect nature of the sampling "
                "procedure.",
                abs(token_mismatch_total),
                sign,
            )
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(requests)
        return requests


# -----------------------------------------------------------------------------
# MMStar Dataset Implementation
# -----------------------------------------------------------------------------


class MMStarDataset(HuggingFaceDataset):
    """
    Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
    refer to: https://github.com/sgl-project/SpecForge/pull/106
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: TokenizerLike,
        num_requests: int,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build image + question requests from MMStar records.

        Args:
            tokenizer: Used to count tokens of the question text.
            num_requests: Number of requests to produce.
            output_len: Expected output length; defaults to
                ``DEFAULT_OUTPUT_LEN`` when ``None``.
            enable_multimodal_chat: When True, the image is embedded into a
                chat-formatted prompt and ``multi_modal_data`` is ``None``.
                Otherwise the prompt is plain text and the image is passed
                separately as ``multi_modal_data``.
            request_id_prefix: Prefix prepended to each record's index to
                form its ``request_id``.
            no_oversample: Forwarded to ``maybe_oversample_requests``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of ``SampleRequest`` objects. ``prompt_len`` counts only
            the question text, in both modes.
        """
        # If --hf-output-len is not set, use the default output length.
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests: list[SampleRequest] = []

        for ind, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            # Split the question text from options
            # (keep only the part before "Options:").
            full_q: str = item.get("question", "")
            question_text = full_q.split("Options:", 1)[0].strip()

            # Multimodal image content.
            mm_content = process_image(item["image"])

            # Compute prompt token length from the plain question text,
            # regardless of enable_multimodal_chat.
            prompt_len = len(tokenizer(question_text).input_ids)

            if enable_multimodal_chat:
                # Embed the image in the chat message content.
                prompt = self.apply_multimodal_chat_transformation(
                    question_text, mm_content
                )
                mm_for_request = None  # Already embedded in chat content.
            else:
                # Default: prompt is plain text,
                # image is in mm_content for the bench to assemble.
                prompt = question_text
                mm_for_request = mm_content

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_for_request,
                    request_id=request_id_prefix + str(ind),
                )
            )

        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )
        return sampled_requests