benchmark_dataset.py 41.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
  - ShareGPT
  - Random (synthetic)
  - Custom (JSONL)
  - Sonnet
  - BurstGPT
  - HuggingFace (e.g. Conversation)
  - VisionArena
"""

import base64
import io
import json
18
import logging
19
20
21
22
23
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
24
25
from io import BytesIO
from typing import Any, Callable, Optional, Union
26
27
28
29
30
31
32
33
34
35

import numpy as np
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
36
from vllm.multimodal.image import convert_image_mode
37
38
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

39
40
# Module-level logger, named after this module's import path (__name__).
logger = logging.getLogger(__name__)

41
42
43
44
45
46
47
48
49
50
51
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------


@dataclass
class SampleRequest:
    """
    A single request produced by a benchmark dataset.

    Holds the prompt to send together with the token counts the benchmark
    uses for bookkeeping, plus optional multimodal and LoRA attachments.
    """

    # Prompt text, or an already-structured prompt such as the chat message
    # list built by BenchmarkDataset.apply_multimodal_chat_transformation.
    prompt: Union[str, Any]
    # Token count the dataset computed for the prompt. Datasets note that it
    # is not accurate once the prompt has been chat-transformed.
    prompt_len: int
    # Number of tokens the request is expected to generate.
    expected_output_len: int
    # Optional multimodal content attached to the request (e.g. an image).
    multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
    # Optional LoRA adapter to serve this request with.
    lora_request: Optional[LoRARequest] = None


# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------


class BenchmarkDataset(ABC):
    """
    Abstract base class shared by the benchmark datasets in this module.

    Subclasses implement ``sample`` (and usually ``load_data``). This base
    class stores the dataset path and random seed. It also provides helpers
    for chat formatting, random LoRA selection and oversampling.
    """

    DEFAULT_SEED = 0
    IS_MULTIMODAL = False

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
    ) -> None:
        """
        Store the dataset location and the seed used for shuffling/sampling.

        Args:
            dataset_path (Optional[str]): Location of the dataset. ``None``
                indicates that a default or synthetic dataset may be used.
            random_seed (int): Seed for reproducible shuffling or sampling.
                ``None`` is tolerated and replaced by ``DEFAULT_SEED``.
        """
        self.dataset_path = dataset_path
        if random_seed is None:
            random_seed = self.DEFAULT_SEED
        self.random_seed = random_seed
        # Subclasses typically fill this in from load_data().
        self.data = None

    def apply_multimodal_chat_transformation(
        self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
    ) -> list[dict]:
        """
        Wrap a prompt (and optional multimodal item) as one user chat message.

        The result is a single-message conversation. The message's content
        list holds a text part followed, when given, by ``mm_content``. It is
        meant for chat models that expect this conversation format.
        """
        parts = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            parts.append(mm_content)
        user_message = {"role": "user", "content": parts}
        return [user_message]

    def load_data(self) -> None:
        """
        Populate ``self.data`` from ``self.dataset_path``.

        The base implementation only raises. Each subclass knows its own file
        format and data source and must override this method.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError("load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
    ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
        """
        Pick a random LoRA adapter (if LoRA is configured) and its tokenizer.

        LoRA is used only when both ``max_loras`` and ``lora_path`` are given.
        In that case:
          - An id is drawn uniformly from ``[1, max_loras]`` using the
            module-level ``random`` generator.
          - A LoRARequest is built for ``lora_path``.
          - The tokenizer for that id is loaded once and memoized in
            ``lora_tokenizer_cache``.

        Args:
            tokenizer (PreTrainedTokenizerBase): Base tokenizer. It is
                returned when LoRA is disabled or when no adapter tokenizer
                is cached for the chosen id.
            max_loras (Optional[int]): Number of LoRA ids to choose from.
                ``None`` disables LoRA.
            lora_path (Optional[str]): Adapter location. ``None`` disables
                LoRA.

        Returns:
            tuple[Optional[LoRARequest], AnyTokenizer]: The chosen request
            (``None`` when LoRA is disabled) and the tokenizer to use with it.
        """
        lora_enabled = max_loras is not None and lora_path is not None
        if not lora_enabled:
            return None, tokenizer

        chosen_id = random.randint(1, max_loras)
        request = LoRARequest(
            lora_name=str(chosen_id),
            lora_int_id=chosen_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        if chosen_id not in lora_tokenizer_cache:
            lora_tokenizer_cache[chosen_id] = get_lora_tokenizer(request)
        # A falsy cached entry means "no adapter tokenizer": fall back to base.
        adapter_tokenizer = lora_tokenizer_cache[chosen_id]
        return request, adapter_tokenizer or tokenizer

    @abstractmethod
    def sample(
        self, tokenizer: PreTrainedTokenizerBase, num_requests: int
    ) -> list[SampleRequest]:
        """
        Produce benchmark requests from this dataset.

        Every concrete dataset overrides this with its own sampling logic.

        Args:
            tokenizer (PreTrainedTokenizerBase): Tokenizer used to process the
                dataset's text.
            num_requests (int): How many requests to produce.

        Returns:
            list[SampleRequest]: The generated requests.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self, requests: list[SampleRequest], num_requests: int
    ) -> None:
        """
        Top ``requests`` up in place to ``num_requests`` entries if needed.

        Missing entries are drawn with replacement from the existing ones. The
        global ``random`` module is re-seeded with ``self.random_seed`` first
        so the padding is reproducible.

        Args:
            requests (list[SampleRequest]): Requests sampled so far (mutated).
            num_requests (int): Desired total number of requests.
        """
        shortfall = num_requests - len(requests)
        if shortfall <= 0:
            return
        random.seed(self.random_seed)
        requests.extend(random.choices(requests, k=shortfall))
        logger.info("Oversampled requests to reach %d total samples.", num_requests)
193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216

# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------


def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Decide whether a (prompt, output) length pair passes the pruning rules.

    A pair is rejected when any of the following holds:
      - the prompt is shorter than ``min_len``;
      - the output is shorter than ``min_len`` (unless
        ``skip_min_output_len_check`` is set);
      - the prompt is longer than ``max_prompt_len``;
      - prompt plus output exceed ``max_total_len``.

    The default thresholds come from the original ``sample_hf_requests`` and
    ``sample_sharegpt_requests`` functions in benchmark_serving.py, and from
    ``sample_requests`` in benchmark_throughput.py.
    """
    if prompt_len < min_len:
        return False
    if not skip_min_output_len_check and output_len < min_len:
        return False
    if prompt_len > max_prompt_len:
        return False
    return prompt_len + output_len <= max_total_len
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239


@cache
def lora_path_on_disk(lora_path: str) -> str:
    """
    Resolve ``lora_path`` via ``get_adapter_absolute_path``.

    Results are memoized per distinct ``lora_path`` by ``functools.cache``,
    so repeated lookups for the same adapter reuse the first answer.
    """
    resolved = get_adapter_absolute_path(lora_path)
    return resolved


# Process-wide cache mapping a LoRA integer id to the tokenizer loaded for it.
# Entries may be falsy when no adapter-specific tokenizer was obtained; the
# lookup in BenchmarkDataset.get_random_lora_request then falls back to the
# base tokenizer.
lora_tokenizer_cache: dict[int, AnyTokenizer] = dict()


def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports three input types:

    1. String input: treated as a URL or local file path. Strings that
       already start with "http://", "https://" or "file://" are used as-is.
       Anything else is assumed to be a local path and gets "file://"
       prepended. Returns a dictionary with the image URL.

    2. Dictionary with raw image bytes: expects a dict with a 'bytes' key
       containing raw image data. The bytes are loaded as a PIL.Image.Image
       and then handled as in case 3.

    3. PIL.Image.Image input: converted to RGB mode, saved as an in-memory
       JPEG, and returned as a dictionary holding a base64 data URL.

    Raises:
        ValueError: If the input is not a supported type.
    """
    # Strings never match the dict / PIL branches, so handling them first is
    # behaviour-neutral and avoids touching PIL for plain URLs and paths.
    if isinstance(image, str):
        # "https://" must be accepted too; otherwise HTTPS URLs would be
        # mangled into "file://https://...".
        image_url = (
            image
            if image.startswith(("http://", "https://", "file://"))
            else f"file://{image}"
        )
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(BytesIO(image["bytes"]))

    if isinstance(image, Image.Image):
        image = convert_image_mode(image, "RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
        }

    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image"
        " or str or dictionary with raw image bytes."
    )
278
279
280
281
282
283
284
285
286
287


# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------


class RandomDataset(BenchmarkDataset):
    """
    Synthetic dataset whose prompts are built from token ids, not read from
    disk.

    Prompt and output lengths are drawn uniformly around the requested
    values. All randomness comes from a per-instance NumPy generator seeded
    with ``random_seed``, so a given seed reproduces the same requests.
    """

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Dedicated generator seeded from random_seed. Previously the global
        # NumPy RNG was used unseeded, so random_seed had no effect and runs
        # were not reproducible. It also keeps NumPy's global state untouched.
        self._rng = np.random.default_rng(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate ``num_requests`` synthetic requests.

        Args:
            tokenizer: Tokenizer used to decode token ids into prompt text.
            num_requests: Number of requests to produce.
            prefix_len: Number of random token ids shared as a common prefix
                by every prompt.
            range_ratio: Relative spread ``b``. Lengths are drawn uniformly
                from ``[X * (1 - b), X * (1 + b)]``; must be in ``[0, 1)``.
            input_len: Target prompt length, including the tokenizer's special
                tokens (they are subtracted before sampling).
            output_len: Target output length.

        Returns:
            list[SampleRequest]: One request per sample. Each ``prompt_len``
            is the token length after the decode/re-encode round trip.
        """
        # A negative ratio would make the low bound exceed the high bound.
        assert 0.0 <= range_ratio < 1.0, (
            "random_range_ratio must be in [0.0, 1.0) to ensure a valid "
            "sampling range"
        )

        rng = self._rng
        vocab_size = tokenizer.vocab_size
        num_special_tokens = tokenizer.num_special_tokens_to_add()
        real_input_len = input_len - num_special_tokens

        prefix_token_ids = (
            rng.integers(0, vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

        # Sampling window: [X * (1 - b), X * (1 + b)]
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        output_high = int(output_len * (1 + range_ratio))

        # Add logging for debugging
        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low, output_high)

        # Upper bounds are exclusive, hence the +1 to include *_high.
        input_lens = rng.integers(input_low, input_high + 1, size=num_requests)
        output_lens = rng.integers(output_low, output_high + 1, size=num_requests)
        offsets = rng.integers(0, vocab_size, size=num_requests)

        requests = []
        for i, (in_len, out_len, offset) in enumerate(
            zip(input_lens, output_lens, offsets)
        ):
            inner_seq = ((offset + i + np.arange(in_len)) % vocab_size).tolist()
            token_sequence = prefix_token_ids + inner_seq
            prompt = tokenizer.decode(token_sequence)
            # After decoding the prompt we have to encode and decode it again.
            # This is done because in some cases N consecutive tokens
            # give a string tokenized into != N number of tokens.
            # For example for GPT2Tokenizer:
            # [6880, 6881] -> ['Ġcalls', 'here'] ->
            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
            # To avoid uncontrolled change of the prompt length,
            # the encoded sequence is truncated before being decoded again.
            total_input_len = prefix_len + int(in_len)
            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
                :total_input_len
            ]
            prompt = tokenizer.decode(re_encoded_sequence)
            total_input_len = len(re_encoded_sequence)
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(out_len),
                )
            )
        return requests


# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------


class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset.  Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Load the JSON file, keep usable entries and shuffle them.

        Only entries with a "conversations" list of at least two turns are
        kept. The list is then shuffled with the global ``random`` module
        seeded by ``self.random_seed``.

        Raises:
            ValueError: If no dataset path was given.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Filter entries with at least two conversation turns.
        self.data = [
            entry
            for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` requests from the first two turns of
        each conversation, oversampling if too few entries pass filtering.

        Args:
            tokenizer: Base tokenizer used to measure prompt/completion length.
            num_requests: Desired number of requests.
            lora_path, max_loras: Optional LoRA configuration. When both are
                set, each request gets a randomly chosen adapter.
            output_len: Fixed output length. ``None`` uses the token length of
                the recorded completion.
            enable_multimodal_chat: Wrap the prompt as a chat message list.
        """
        samples: list = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            prompt, completion = (
                entry["conversations"][0]["value"],
                entry["conversations"][1]["value"],
            )

            # Keep the per-request tokenizer in its own name. Rebinding
            # ``tokenizer`` would make an earlier LoRA tokenizer the fallback
            # "base" tokenizer for every later request.
            lora_request, req_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            prompt_ids = req_tokenizer(prompt).input_ids
            completion_ids = req_tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
            new_output_len = len(completion_ids) if output_len is None else output_len
            if not is_valid_sequence(
                prompt_len,
                new_output_len,
                skip_min_output_len_check=output_len is not None,
            ):
                continue
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, None)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                )
            )
        self.maybe_oversample_requests(samples, num_requests)
        return samples


444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------


class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset.  Loads data from a JSONL file in which
    every line is an object with (at least) a "prompt" field, and generates
    one sample request per prompt. E.g.,
    ```
    {"prompt": "What is the capital of India?"}
    {"prompt": "What is the capital of Iran?"}
    {"prompt": "What is the capital of China?"}
    ```
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read ``dataset_path`` into ``self.data`` as a shuffled list of dicts.

        Raises:
            ValueError: If no path was given, or the JSONL file has no
                'prompt' column.
            NotImplementedError: If the file is not a ``.jsonl`` file.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # self.data is normalized to a list of dicts, one per record, e.g.
        # [{"prompt": "What is the capital of India?"}, ...]. load_data()
        # converts into this format whatever the file type of dataset_path,
        # and sample() relies on it.
        self.data = []

        if self.dataset_path.endswith(".jsonl"):
            jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)

            # check if the JSONL file has a 'prompt' column
            if "prompt" not in jsonl_data.columns:
                raise ValueError("JSONL file must contain a 'prompt' column.")

            # Convert the whole frame in one call. Unlike iterrows(), this
            # does not build a Series per row (slow) and does not upcast
            # per-row values to a common dtype (e.g. ints to floats).
            self.data = jsonl_data.to_dict(orient="records")
        else:
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset."
            )

        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` requests from the loaded prompts,
        oversampling if the file holds fewer prompts than requested.

        Args:
            tokenizer: Tokenizer used for the chat template and length count.
            num_requests: Desired number of requests.
            lora_path, max_loras, enable_multimodal_chat: Accepted but not
                used by this dataset.
            output_len: Stored verbatim as each request's expected output
                length. NOTE(review): the default is None, which yields
                requests without an integer output length; callers presumably
                always pass a value.
            skip_chat_template: When False (default), each prompt is wrapped
                in the tokenizer's chat template as a single user turn.
        """
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]

            # apply template
            if not skip_chat_template:
                prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)

        return sampled_requests


535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------


class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset.  Loads poem lines from a
    text file and generates sample requests.  Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Read the poem file into ``self.data``, one entry per line."""
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = f.readlines()

    def sample(
        self,
        tokenizer,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        **kwargs,
    ) -> list:
        """
        Build ``num_requests`` prompts from poem lines.

        Every prompt has the same shape:
          - a fixed instruction;
          - the first lines of the file as a shared prefix (sized from
            ``prefix_len``);
          - randomly chosen extra lines.
        The number of lines is estimated from the average token length of a
        line. Candidates whose chat-formatted length exceeds ``input_len``
        are rejected and re-drawn.

        Raises:
            ValueError: If the file has no lines, if ``input_len`` does not
                exceed the base prompt length, or if the shared prefix alone
                already exceeds ``input_len``. In the last case every
                candidate would be rejected and sampling would loop forever.
        """
        if not self.data:
            raise ValueError("The sonnet dataset contains no lines to sample.")

        # Calculate average token length for a poem line.
        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
        avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)

        # Build the base prompt.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_msg = [{"role": "user", "content": base_prompt}]
        base_fmt = tokenizer.apply_chat_template(
            base_msg, add_generation_prompt=True, tokenize=False
        )
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset})."
            )

        # Determine how many poem lines to use.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]

        # Every candidate contains the shared prefix. If the prefix alone is
        # already too long, the rejection loop below can never finish.
        prefix_fmt = tokenizer.apply_chat_template(
            [{"role": "user", "content": f"{base_prompt}{''.join(prefix_lines)}"}],
            add_generation_prompt=True,
            tokenize=False,
        )
        prefix_only_len = len(tokenizer(prefix_fmt).input_ids)
        if prefix_only_len > input_len:
            raise ValueError(
                f"The shared prefix ({prefix_only_len} tokens including the "
                f"base prompt) exceeds 'input_len' ({input_len}); reduce "
                f"'prefix_len' or increase 'input_len'."
            )

        samples = []
        while len(samples) < num_requests:
            extra_lines = random.choices(
                self.data, k=num_input_lines - num_prefix_lines
            )
            prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
            msg = [{"role": "user", "content": prompt}]
            prompt_formatted = tokenizer.apply_chat_template(
                msg, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)
            if prompt_len <= input_len:
                samples.append(
                    SampleRequest(
                        prompt=prompt_formatted if return_prompt_formatted else prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                    )
                )
        return samples


# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------


class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset.  Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(
        self,
    ) -> None:
        """
        Load the CSV and keep only successful GPT-4 rows in ``self.data``.

        Raises:
            ValueError: If no dataset path was given.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        df = pd.read_csv(self.dataset_path)
        # Filter to keep only GPT-4 rows.
        gpt4_df = df[df["Model"] == "GPT-4"]
        # Remove failed requests (where Response tokens is 0 or less).
        gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        """
        Draw ``num_requests`` rows reproducibly (seeded by random_seed).

        Rows are sampled with replacement only when more rows are requested
        than are available. The result is returned as a list of row lists.
        """
        if num_requests <= len(self.data):
            data = self.data.sample(n=num_requests, random_state=self.random_seed)
        else:
            data = self.data.sample(
                n=num_requests,
                random_state=self.random_seed,
                replace=True,
            )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate synthetic prompts whose lengths follow the sampled trace rows.

        Args:
            tokenizer: Base tokenizer used to decode synthetic token ids.
            num_requests: Number of requests to produce.
            max_loras, lora_path: Optional LoRA configuration. When both are
                set, each request gets a randomly chosen adapter.
        """
        samples = []
        data = self._sample_loaded_data(num_requests=num_requests)
        for i, row in enumerate(data):
            # Columns 2 and 3 are presumably "Request tokens" and
            # "Response tokens"; this depends on the CSV column order.
            input_len = int(row[2])
            output_len = int(row[3])
            # Keep the per-request tokenizer separate so that an earlier LoRA
            # tokenizer never becomes the fallback for later requests.
            lora_req, req_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            vocab_size = req_tokenizer.vocab_size
            # Generate a synthetic prompt: a list of token IDs computed as (i +
            # j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            prompt = req_tokenizer.decode(token_ids)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                )
            )
        return samples


# -----------------------------------------------------------------------------
693
# HuggingFace Dataset Base Implementation
694
695
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Base class for datasets hosted on HuggingFace.

    On construction the requested split (optionally a named subset) is loaded
    with ``datasets.load_dataset`` and shuffled with the dataset's random seed.
    Subclasses list the dataset paths they understand in
    ``SUPPORTED_DATASET_PATHS``.
    """

    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        no_stream: bool = False,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(dataset_path=dataset_path, **kwargs)
        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        # Streaming is on by default; passing no_stream=True turns it off.
        self.load_stream = False if no_stream else True
        self.load_data()

    def load_data(self) -> None:
        """Load the configured split via ``load_dataset`` and shuffle it."""
        self.data = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=self.load_stream,
        )
        # Deterministic shuffle keyed on the dataset's random seed.
        self.data = self.data.shuffle(seed=self.random_seed)


# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------


class ConversationDataset(HuggingFaceDataset):
    """Dataset for conversation data with multimodal support.

    Each row's first two conversation turns become the prompt and the
    reference completion. A non-None ``image`` field becomes the request's
    multimodal content.
    """

    SUPPORTED_DATASET_PATHS = {
        "lmms-lab/LLaVA-OneVision-Data",
        "Aeala/ShareGPT_Vicuna_unfiltered",
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` requests, oversampling if needed.

        Args:
            tokenizer: Tokenizer used to measure prompt/completion length.
            num_requests: Desired number of requests.
            output_len: Fixed output length. ``None`` uses the completion's
                token length and prunes rows via ``is_valid_sequence``.
            enable_multimodal_chat: Wrap prompt and image as a chat message
                list.
        """
        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []
        dynamic_output = output_len is None

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]

            prompt_ids = tokenizer(prompt).input_ids
            completion_ids = tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
            completion_len = len(completion_ids)

            # Prune before asserting. With dynamic output lengths an empty or
            # too-short completion must be skipped, not crash the assertion.
            if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
                continue
            request_output_len = completion_len if dynamic_output else output_len
            assert isinstance(request_output_len, int) and request_output_len > 0

            # A row may carry an "image" key whose value is None (e.g. a
            # text-only sample); treat that as having no image.
            image = item.get("image")
            mm_content = process_image(image) if image is not None else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=request_output_len,
                    multi_modal_data=mm_content,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------


790
class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.

    Each item contributes one text prompt (extracted by the parser registered
    for the dataset path in ``SUPPORTED_DATASET_PATHS``) and its first image.
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` image + text requests.

        Args:
            tokenizer: Used to measure the length of the raw text prompt.
            num_requests: Target number of requests; the collected list is
                passed to ``maybe_oversample_requests`` before returning.
            output_len: Expected output length for every request. Defaults
                to ``DEFAULT_OUTPUT_LEN``.
            enable_multimodal_chat: If True, the prompt is converted with
                ``apply_multimodal_chat_transformation``. ``prompt_len`` is
                still measured on the untransformed text.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            ValueError: If ``self.dataset_path`` has no registered parser.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The parser depends only on the dataset path, so resolve and validate
        # it once instead of repeating the lookup for every item.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")

        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is the dataset designed for general code editing. It consists
    of 114,239 instruction-input-output triplets, and covers multiple distinct
    code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` code-editing chat requests.

        Each item's ``input`` and ``instruction`` are combined into one user
        message and rendered with ``tokenizer.apply_chat_template``
        (``add_generation_prompt=True``, ``tokenize=False``). ``prompt_len``
        is measured by tokenizing the rendered string.

        Args:
            tokenizer: Provides the chat template and measures prompt length.
            num_requests: Target number of requests; the collected list is
                passed to ``maybe_oversample_requests`` before returning.
            output_len: Expected output length for every request. Defaults
                to ``DEFAULT_OUTPUT_LEN``.
            enable_multimodal_chat: Accepted for interface parity; unused.

        Returns:
            A list of ``SampleRequest`` objects.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            # Implicit string concatenation, rather than a backslash line
            # continuation inside the literal, keeps the source indentation of
            # the continued line out of the prompt text.
            prompt = (
                f"{item['input']}\n\n{item['instruction']} Just output "
                "the code, do not include any explanation."
            )

            # apply template
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


892
893
894
895
896
897
898
899
900
901
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    Only the first turn of every conversation is used, turning MT-Bench into a
    single-turn dataset. This mirrors the spec-decoding benchmark setup in vLLM:
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` single-turn chat requests.

        The first entry of each item's ``turns`` becomes a user message.
        That message is rendered with ``tokenizer.apply_chat_template``, and
        its token count (after rendering) is recorded as ``prompt_len``.

        Args:
            tokenizer: Provides the chat template and measures prompt length.
            num_requests: Target number of requests; the collected list is
                handed to ``maybe_oversample_requests`` before returning.
            output_len: Expected output length for every request; falls back
                to ``DEFAULT_OUTPUT_LEN`` when omitted.
            enable_multimodal_chat: Accepted for interface parity; unused.

        Returns:
            A list of ``SampleRequest`` objects.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        requests = []
        for row in self.data:
            if len(requests) >= num_requests:
                break
            first_turn = row["turns"][0]
            chat_prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": first_turn}],
                add_generation_prompt=True,
                tokenize=False,
            )
            requests.append(
                SampleRequest(
                    prompt=chat_prompt,
                    prompt_len=len(tokenizer(chat_prompt).input_ids),
                    expected_output_len=output_len,
                )
            )

        self.maybe_oversample_requests(requests, num_requests)
        return requests


947
948
949
950
951
952
953
954
955
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------


class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing an AIMO dataset with reasoning questions.

    Each item's ``problem`` is used as the prompt; its ``solution`` supplies
    the expected output length when no fixed ``output_len`` is requested.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime",
        "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT",
    }
    # Limits passed to ``is_valid_sequence`` when the expected output length
    # is taken from the reference solution (i.e. ``output_len`` is None).
    MAX_PROMPT_LEN = 2048
    MAX_TOTAL_LEN = 32000

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` problem-solving requests.

        Args:
            tokenizer: Used to measure prompt and solution token lengths.
            num_requests: Target number of requests; the collected list is
                passed to ``maybe_oversample_requests`` before returning.
            output_len: Fixed expected output length for every request. If
                None, each request uses the token length of its ``solution``.
                Items are then filtered with ``is_valid_sequence`` using
                ``MAX_PROMPT_LEN`` / ``MAX_TOTAL_LEN``.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            AssertionError: If the resolved output length is not a positive
                int (e.g. a non-positive fixed ``output_len``).
        """
        sampled_requests = []
        dynamic_output = output_len is None

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt, completion = item["problem"], item["solution"]

            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)
            # Filter before asserting: in dynamic mode a degenerate solution
            # (e.g. one tokenizing to zero tokens) can then be rejected by the
            # length check instead of aborting the whole sampling run.
            # (Presumably is_valid_sequence enforces a minimum output length.)
            if dynamic_output and not is_valid_sequence(
                prompt_len,
                completion_len,
                max_prompt_len=self.MAX_PROMPT_LEN,
                max_total_len=self.MAX_TOTAL_LEN,
            ):
                continue
            # Separate local instead of rebinding the ``output_len`` argument
            # on every iteration.
            request_output_len = completion_len if dynamic_output else output_len
            assert isinstance(request_output_len, int) and request_output_len > 0
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=request_output_len,
                    multi_modal_data=None,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------


# Prompt template for the Next Edit Prediction (zeta) dataset. The two ``{}``
# placeholders are filled, in order, with the user edit events and the
# excerpt to rewrite.
zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### User Edits:

{}

### User Excerpt:

{}

### Response:

"""  # noqa: E501


def _format_zeta_prompt(
    sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.

    This function formats examples from the NEP dataset
    into prompts and expected outputs. It could be
    further extended to support more NEP datasets.

    Args:
        sample: The dataset sample containing events,
            inputs, and outputs (keys ``"events"``, ``"input"``,
            ``"output"``).
        original_start_marker: The marker indicating the
            start of the editable region. Defaults to
            "<|editable_region_start|>".

    Returns:
        A dictionary with the formatted prompt (``"prompt"``) and the expected
        output (``"expected_output"``). The expected output is the raw output
        from the first occurrence of ``original_start_marker`` onward, or the
        whole raw output if the marker does not occur.
    """
    events = sample["events"]
    excerpt = sample["input"]
    output = sample["output"]
    prompt = zeta_prompt.format(events, excerpt)

    # following the original implementation, extract the focused region
    # from the raw output. str.find returns -1 when the marker is absent;
    # slicing from -1 would keep only the last character, so fall back to
    # the whole output in that case.
    output_start_index = output.find(original_start_marker)
    if output_start_index < 0:
        output_start_index = 0
    expected_output = output[output_start_index:]

    return {"prompt": prompt, "expected_output": expected_output}


class NextEditPredictionDataset(HuggingFaceDataset):
    """
    Dataset class for processing a Next Edit Prediction dataset.
    """

    SUPPORTED_DATASET_PATHS = {
        "zed-industries/zeta",
    }
    MAPPING_PROMPT_FUNCS = {
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(
        self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs
    ) -> list:
        """Build up to ``num_requests`` next-edit-prediction requests.

        Each raw item is converted by the formatter registered in
        ``MAPPING_PROMPT_FUNCS`` for ``self.dataset_path``. The formatted
        ``"prompt"`` and ``"expected_output"`` are both tokenized to obtain
        the prompt length and expected output length.

        Args:
            tokenizer: Used to measure prompt and expected-output lengths.
            num_requests: Target number of requests; the collected list is
                passed to ``maybe_oversample_requests`` before returning.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            ValueError: If no formatter is registered for
                ``self.dataset_path``.
        """
        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
        if formatting_prompt_func is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
        samples = []
        for raw_sample in self.data:
            # Check before appending (as the other datasets do) so that
            # num_requests <= 0 yields no samples rather than one.
            if len(samples) >= num_requests:
                break
            formatted = formatting_prompt_func(raw_sample)
            samples.append(
                SampleRequest(
                    prompt=formatted["prompt"],
                    prompt_len=len(tokenizer(formatted["prompt"]).input_ids),
                    expected_output_len=len(
                        tokenizer(formatted["expected_output"]).input_ids
                    ),
                )
            )
        self.maybe_oversample_requests(samples, num_requests)
        return samples


1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing an ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr,  ...        |
    | LibriSpeech    | Audiobook                              | Narrated                 | clean, other, all           |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+

    When ``skip_long_audios`` is set, clips longer than ``MAX_AUDIO_DURATION_S``
    seconds are skipped, and the number skipped is logged as a warning.
    """  # noqa: E501

    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }

    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    # Whisper max supported duration, in seconds.
    MAX_AUDIO_DURATION_S = 30
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` transcription requests.

        Every request uses ``TRANSCRIPTION_PREAMBLE`` as its text prompt and
        carries the item's audio as ``{"audio": (array, sampling_rate)}``.

        Args:
            tokenizer: Used to measure the length of the shared text prompt.
            num_requests: Target number of requests; the collected list is
                passed to ``maybe_oversample_requests`` before returning.
            output_len: Expected output length for every request. Defaults
                to ``DEFAULT_OUTPUT_LEN``.

        Returns:
            A list of ``SampleRequest`` objects.
        """
        # Imported lazily, presumably so librosa is only required when an
        # ASR dataset is actually sampled.
        import librosa

        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The text prompt is identical for every request; only the audio varies.
        # Looked up via ``self`` so subclasses can override the preamble.
        prompt = self.TRANSCRIPTION_PREAMBLE
        prompt_len = len(tokenizer(prompt).input_ids)
        sampled_requests = []
        skipped = 0
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            audio = item["audio"]
            y, sr = audio["array"], audio["sampling_rate"]
            duration_s = librosa.get_duration(y=y, sr=sr)
            if self.skip_long_audios and duration_s > self.MAX_AUDIO_DURATION_S:
                skipped += 1
                continue

            mm_content = {"audio": (y, sr)}
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                )
            )
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
                " their length being greater than"
                " what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests