benchmark_dataset.py 44.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
  - ShareGPT
  - Random (synthetic)
  - Custom (JSONL prompts)
  - Sonnet
  - BurstGPT
  - HuggingFace
  - VisionArena
"""

import base64
import io
import json
18
import logging
19
20
21
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
22
from copy import deepcopy
23
24
from dataclasses import dataclass
from functools import cache
25
26
from io import BytesIO
from typing import Any, Callable, Optional, Union
27
28
29
30
31
32
33
34
35
36

import numpy as np
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
37
from vllm.multimodal.image import convert_image_mode
38
39
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

40
41
# Module-level logger shared by all dataset classes in this file.
logger = logging.getLogger(__name__)

42
43
44
45
46
47
48
49
50
51
52
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------


@dataclass
class SampleRequest:
    """
    One inference request handed to the benchmark.

    The first three fields are required; the remaining fields are optional
    and default to None.
    """

    # Prompt payload: usually a string, but any object is accepted (for
    # example the chat-message list built by
    # BenchmarkDataset.apply_multimodal_chat_transformation).
    prompt: Union[str, Any]
    # Prompt length in tokens, as computed by the dataset that built it.
    prompt_len: int
    # Number of output tokens expected for this request.
    expected_output_len: int
    # Optional multimodal content attached to the request.
    multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
    # Optional LoRA request to associate with this sample.
    lora_request: Optional[LoRARequest] = None
    # Optional identifier; datasets in this module build it as
    # ``request_id_prefix + str(index)``.
    request_id: Optional[str] = None
59
60
61
62
63
64
65
66
67


# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------


class BenchmarkDataset(ABC):
    """
    Abstract base class for the benchmark datasets defined in this module.

    Subclasses must implement `sample` and normally override `load_data`.

    Attributes:
        DEFAULT_SEED: Seed used when `random_seed` is passed as None.
        IS_MULTIMODAL: Class-level flag, False by default; subclasses may
            override it.
    """

    DEFAULT_SEED = 0
    IS_MULTIMODAL = False

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
    ) -> None:
        """
        Initialize the BenchmarkDataset with an optional dataset path and
        random seed.

        Args:
            dataset_path (Optional[str]): Path to the dataset. If None, it
                indicates that a default or random dataset might be used.
            random_seed (int): Seed value for reproducible shuffling or
                sampling. Defaults to DEFAULT_SEED.
        """
        self.dataset_path = dataset_path
        # Set the random seed, ensuring that a None value is replaced with the
        # default seed.
        self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
        self.data = None

    def apply_multimodal_chat_transformation(
        self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
    ) -> list[dict]:
        """
        Transform a prompt and optional multimodal content into a chat format.

        This method is used for chat models that expect a specific conversation
        format.

        Args:
            prompt (str): The text of the user turn.
            mm_content (Optional[MultiModalDataDict]): Extra content entry
                appended after the text entry when not None.

        Returns:
            list[dict]: A single-message conversation of the form
            ``[{"role": "user", "content": [...]}]``.
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            content.append(mm_content)
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:
        """
        Load data from the dataset path into self.data.

        This method must be overridden by subclasses since the method to load
        data will vary depending on the dataset format and source.

        Raises:
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError("load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
    ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
        """
        Optionally select a random LoRA request and return its associated
        tokenizer.

        This method is used when LoRA parameters are provided. It randomly
        selects a LoRA based on max_loras and retrieves a cached tokenizer for
        that LoRA if available. Otherwise, it returns the base tokenizer.

        Args:
            tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if
                no LoRA is selected.
            max_loras (Optional[int]): The maximum number of LoRAs available.
                If None, LoRA is not used.
            lora_path (Optional[str]): Path to the LoRA parameters on disk.
                If None, LoRA is not used.

        Returns:
            tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
            element is a LoRARequest (or None if not applicable) and the second
            element is the tokenizer associated with the LoRA request (or the
            base tokenizer).
        """
        if max_loras is None or lora_path is None:
            return None, tokenizer

        # Generate a random LoRA ID in the range [1, max_loras].
        lora_id = random.randint(1, max_loras)
        lora_request = LoRARequest(
            lora_name=str(lora_id),
            lora_int_id=lora_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        # Populate the module-level cache on first use of this LoRA ID.
        if lora_id not in lora_tokenizer_cache:
            lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
        # Return lora_request and the cached tokenizer if it is truthy;
        # otherwise fall back to the tokenizer passed in by the caller.
        return lora_request, lora_tokenizer_cache[lora_id] or tokenizer

    @abstractmethod
    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
    ) -> list[SampleRequest]:
        """
        Abstract method to generate sample requests from the dataset.

        Subclasses must override this method to implement dataset-specific logic
        for generating a list of SampleRequest objects.

        Args:
            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.
            request_id_prefix (str): The prefix of request_id.

        Returns:
            list[SampleRequest]: A list of sample requests generated from the
            dataset.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self,
        requests: list[SampleRequest],
        num_requests: int,
        request_id_prefix: str = "",
    ) -> None:
        """
        Oversamples the list of requests if its size is less than the desired
        number.

        The list is extended in place with copies of randomly chosen existing
        requests. Each copy receives a fresh request id that continues the
        numbering after the existing entries.

        Args:
            requests (list[SampleRequest]): The current list of sampled
                requests. Mutated in place.
            num_requests (int): The target number of requests.
            request_id_prefix (str): The prefix of the request ids.

        Raises:
            IndexError: Propagated from random.choices when `requests` is empty
                and `num_requests` is positive.
        """
        missing = num_requests - len(requests)
        if missing <= 0:
            return

        random.seed(self.random_seed)
        picked = random.choices(requests, k=missing)
        first_new_index = len(requests)
        additional = []
        for offset, original in enumerate(picked):
            # Copy every pick on its own. Deep-copying the whole list at once
            # would map a request picked twice to a single shared copy (via
            # deepcopy's memo), and the id assignment below would then
            # overwrite itself, yielding duplicate request ids.
            req = deepcopy(original)
            req.request_id = request_id_prefix + str(first_new_index + offset)
            additional.append(req)
        requests.extend(additional)
        logger.info("Oversampled requests to reach %d total samples.", num_requests)
209

210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232

# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------


def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Decide whether a (prompt, output) length pair should be kept.

    A pair is rejected when any of the following holds:
      - the prompt is shorter than `min_len`;
      - the prompt is longer than `max_prompt_len`;
      - the output is shorter than `min_len` (unless
        `skip_min_output_len_check` is True);
      - prompt and output together exceed `max_total_len`.

    Default pruning criteria are copied from the original `sample_hf_requests`
    and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
    from `sample_requests` in benchmark_throughput.py.

    Returns:
        bool: True if none of the rejection conditions apply.
    """
    # Prompt length must lie within [min_len, max_prompt_len].
    if prompt_len < min_len or prompt_len > max_prompt_len:
        return False

    # The minimum output length is enforced only when not explicitly skipped.
    if not skip_min_output_len_check and output_len < min_len:
        return False

    # Finally, the combined length must not exceed the total budget.
    return not ((prompt_len + output_len) > max_total_len)
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255


@cache
def lora_path_on_disk(lora_path: str) -> str:
    """
    Return the result of `get_adapter_absolute_path` for `lora_path`.

    The result is memoized per distinct `lora_path` via `functools.cache`, so
    the underlying lookup runs at most once for each path string.
    """
    resolved_path = get_adapter_absolute_path(lora_path)
    return resolved_path


# Global cache for LoRA tokenizers, keyed by the integer LoRA ID. It is filled
# lazily by BenchmarkDataset.get_random_lora_request.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}


def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports three input types:

    1. String input:
       - Treats the string as a URL or local file path.
       - Prepends "file://" unless the string already starts with "http://",
         "https://" or "file://".
       - Returns a dictionary with the image URL.

    2. Dictionary with raw image bytes:
       - Expects a dict with a 'bytes' key containing raw image data.
       - Loads the bytes as a PIL.Image.Image and continues as in (3).

    3. PIL.Image.Image input:
       - Converts the image to RGB.
       - Saves the image as a JPEG in memory.
       - Encodes the JPEG data as a base64 string.
       - Returns a dictionary with the image as a base64 data URL.

    Args:
        image: A str, a PIL.Image.Image, or a dict holding raw image bytes.

    Returns:
        Mapping[str, Any]: ``{"type": "image_url", "image_url": {"url": ...}}``.

    Raises:
        ValueError: If the input is not a supported type.
    """
    # Strings are checked first: they can never match the dict/PIL branches,
    # so the ordering does not affect any other input type.
    if isinstance(image, str):
        # "https://" must be recognised too; otherwise a secure URL would be
        # mangled into "file://https://...".
        has_scheme = image.startswith(("http://", "https://", "file://"))
        image_url = image if has_scheme else f"file://{image}"
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(BytesIO(image["bytes"]))

    if isinstance(image, Image.Image):
        image = convert_image_mode(image, "RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
        }

    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image"
        " or str or dictionary with raw image bytes."
    )
294
295
296
297
298
299
300
301
302
303


# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------


class RandomDataset(BenchmarkDataset):
    """
    Synthetic dataset that builds prompts from pseudo-random token IDs.
    """

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward all keyword arguments to BenchmarkDataset.__init__."""
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate `num_requests` synthetic requests.

        Args:
            tokenizer (PreTrainedTokenizerBase): Tokenizer used to decode the
                generated token IDs and to re-encode the resulting prompt.
            num_requests (int): Number of requests to generate.
            prefix_len (int): Number of random token IDs shared as a prefix
                by every prompt.
            range_ratio (float): Relative half-width ``b`` of the sampling
                interval ``[X * (1 - b), X * (1 + b)]`` applied to both the
                input and output lengths. Must be < 1.0.
            input_len (int): Target input length, including the special
                tokens the tokenizer reports via
                ``num_special_tokens_to_add()``.
            output_len (int): Target output length.
            request_id_prefix (str): Prefix of each request_id.
            **kwargs: Accepted and ignored.

        Returns:
            list[SampleRequest]: The generated requests. ``prompt_len`` is the
            token count of the re-encoded (and truncated) prompt.

        Raises:
            ValueError: If `range_ratio` is not < 1.0.
        """
        # Enforce range_ratio < 1. This is an explicit check rather than an
        # assert so that it still runs under ``python -O``.
        if not range_ratio < 1.0:
            raise ValueError(
                "random_range_ratio must be < 1.0 to ensure a valid sampling range"
            )

        vocab_size = tokenizer.vocab_size
        num_special_tokens = tokenizer.num_special_tokens_to_add()
        # Clamp at zero: a negative length would invert the sampling range
        # below or silently cut tokens off the shared prefix.
        real_input_len = max(input_len - num_special_tokens, 0)

        prefix_token_ids = (
            np.random.randint(0, vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

        # New sampling logic: [X * (1 - b), X * (1 + b)]
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        # Ensure the lower bound for output length is at least 1 to prevent
        # sampling 0 tokens, which can cause request failures.
        output_low = max(output_low, 1)
        output_high = int(output_len * (1 + range_ratio))

        # Add logging for debugging
        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low, output_high)

        # NOTE: the order of the np.random calls determines the generated
        # data for a given NumPy seed; keep it stable.
        input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
        output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
        offsets = np.random.randint(0, vocab_size, size=num_requests)

        requests = []
        for i in range(num_requests):
            # Consecutive token IDs starting at a per-request offset, wrapped
            # around the vocabulary size.
            inner_seq = (
                (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
            ).tolist()
            token_sequence = prefix_token_ids + inner_seq
            prompt = tokenizer.decode(token_sequence)
            # After decoding the prompt we have to encode and decode it again.
            # This is done because in some cases N consecutive tokens
            # give a string tokenized into != N number of tokens.
            # For example for GPT2Tokenizer:
            # [6880, 6881] -> ['Ġcalls', 'here'] ->
            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
            # To avoid uncontrolled change of the prompt length,
            # the encoded sequence is truncated before being decoded again.
            total_input_len = prefix_len + int(input_lens[i])
            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
                :total_input_len
            ]
            prompt = tokenizer.decode(re_encoded_sequence)
            total_input_len = len(re_encoded_sequence)
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )

        return requests


# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------


class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset.  Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    def __init__(self, **kwargs) -> None:
        """Forward keyword arguments to the base class, then load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the JSON file at `self.dataset_path` into `self.data`.

        Only entries that have a "conversations" field with at least two
        turns are kept; the kept entries are shuffled using
        `self.random_seed`.

        Raises:
            ValueError: If `self.dataset_path` is None.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Filter entries with at least two conversation turns.
        self.data = [
            entry
            for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Build up to `num_requests` requests from the first two turns of each
        conversation, oversampling if too few entries pass validation.

        Args:
            tokenizer (PreTrainedTokenizerBase): Base tokenizer.
            num_requests (int): Number of requests to return.
            lora_path (Optional[str]): LoRA path; LoRA is used only when both
                this and `max_loras` are given.
            max_loras (Optional[int]): Upper bound for the random LoRA ID.
            output_len (Optional[int]): Fixed expected output length. When
                None, the tokenized length of the second turn is used.
            enable_multimodal_chat (bool): If True, the prompt is replaced by
                the chat structure from
                `apply_multimodal_chat_transformation`.
            request_id_prefix (str): Prefix of each request_id.
            **kwargs: Accepted and ignored.

        Returns:
            list[SampleRequest]: The sampled (and possibly oversampled)
            requests.
        """
        samples: list[SampleRequest] = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            conversations = entry["conversations"]
            prompt = conversations[0]["value"]
            completion = conversations[1]["value"]

            # Keep the caller's tokenizer untouched: rebinding `tokenizer`
            # here would make a LoRA tokenizer from an earlier iteration the
            # "base" fallback for every later one.
            lora_request, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            prompt_len = len(request_tokenizer(prompt).input_ids)
            # The completion only needs tokenizing when its length is used.
            if output_len is None:
                new_output_len = len(request_tokenizer(completion).input_ids)
            else:
                new_output_len = output_len
            if not is_valid_sequence(
                prompt_len,
                new_output_len,
                skip_min_output_len_check=output_len is not None,
            ):
                continue
            # TODO: Also support ShareGPT4Video.
            if image_path := entry.get("image"):
                mm_content = process_image(image_path)
            else:
                mm_content = None
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                    multi_modal_data=mm_content,
                    # Ids number the accepted samples consecutively from 0.
                    request_id=request_id_prefix + str(len(samples)),
                )
            )
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples


476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------


class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset.  Loads data from a JSONL file and generates
    sample requests from the "prompt" field of each line. E.g.,
    ```
    {"prompt": "What is the capital of India?"}
    {"prompt": "What is the capital of Iran?"}
    {"prompt": "What is the capital of China?"}
    ```
    """

    def __init__(self, **kwargs) -> None:
        """Forward keyword arguments to the base class, then load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Populate `self.data` with one dictionary per JSONL row, then shuffle.

        `self.data` is the standardized format `sample()` relies on: a list
        of dictionaries such as
        ``[{"prompt": "What is the capital of India?"}, ...]``.

        Raises:
            ValueError: If `self.dataset_path` is None, or the file has no
                'prompt' column.
            NotImplementedError: If the path does not end with ".jsonl".
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # Reset to the standardized (empty) container before any
        # format-specific loading takes place.
        self.data = []

        # Only the JSONL format is understood at the moment.
        if not self.dataset_path.endswith(".jsonl"):
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset."
            )

        frame = pd.read_json(path_or_buf=self.dataset_path, lines=True)

        # Every row must carry a prompt.
        if "prompt" not in frame.columns:
            raise ValueError("JSONL file must contain a 'prompt' column.")

        # One dictionary per DataFrame row.
        self.data.extend(row.to_dict() for _, row in frame.iterrows())

        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """
        Turn the loaded prompts into up to `num_requests` requests,
        oversampling when the file holds fewer prompts than requested.

        Args:
            tokenizer (PreTrainedTokenizerBase): Tokenizer used for the chat
                template and for counting prompt tokens.
            num_requests (int): Number of requests to return.
            lora_path (Optional[str]): Not used by this method.
            max_loras (Optional[int]): Not used by this method.
            output_len (Optional[int]): Stored unchanged as each request's
                expected_output_len.
            enable_multimodal_chat (bool): Not used by this method.
            skip_chat_template (bool): If True, the raw prompt is used
                without applying the tokenizer's chat template.
            request_id_prefix (str): Prefix of each request_id.
            **kwargs: Accepted and ignored.

        Returns:
            list: The sampled (and possibly oversampled) SampleRequest
            objects.
        """
        collected = []
        for position, record in enumerate(self.data):
            if len(collected) >= num_requests:
                break

            text = record["prompt"]
            # Wrap the prompt in a single user turn unless told otherwise.
            if not skip_chat_template:
                conversation = [{"role": "user", "content": text}]
                text = tokenizer.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=False,
                )

            token_count = len(tokenizer(text).input_ids)
            request = SampleRequest(
                prompt=text,
                prompt_len=token_count,
                expected_output_len=output_len,
                request_id=request_id_prefix + str(position),
            )
            collected.append(request)

        self.maybe_oversample_requests(collected, num_requests, request_id_prefix)

        return collected


571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------


class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset.  Loads poem lines from a
    text file and generates sample requests.  Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward keyword arguments to the base class, then load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read every line of the file at `self.dataset_path` into `self.data`.

        Raises:
            ValueError: If `self.dataset_path` is None or empty.
        """
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as poem_file:
            self.data = poem_file.readlines()

    def sample(
        self,
        tokenizer,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """
        Build `num_requests` prompts made of a fixed instruction, a shared
        block of leading poem lines and randomly chosen extra lines.

        Candidates whose chat-formatted token count exceeds `input_len` are
        discarded and redrawn.
        NOTE(review): if no candidate ever fits, the loop does not terminate.

        Args:
            tokenizer: Tokenizer providing `__call__` and
                `apply_chat_template`.
            num_requests (int): Number of requests to return.
            prefix_len (int): Target token budget for the shared prefix.
            input_len (int): Maximum token length of a formatted prompt.
            output_len (int): Expected output length stored on each request.
            return_prompt_formatted (bool): If True, requests carry the
                chat-formatted prompt instead of the raw one.
            request_id_prefix (str): Prefix of each request_id.
            **kwargs: Accepted and ignored.

        Returns:
            list: The generated SampleRequest objects.

        Raises:
            ValueError: If `input_len` does not exceed the token length of the
                formatted base prompt.
        """
        # Average token length of a poem line.
        line_token_counts = [len(tokenizer(line).input_ids) for line in self.data]
        avg_len = sum(line_token_counts) / len(line_token_counts)

        # Token cost of the bare instruction once chat-formatted.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_fmt = tokenizer.apply_chat_template(
            [{"role": "user", "content": base_prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset})."
            )

        # Number of poem lines in total and in the shared prefix.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]
        num_extra_lines = num_input_lines - num_prefix_lines

        samples = []
        while len(samples) < num_requests:
            extra_lines = random.choices(self.data, k=num_extra_lines)
            prompt = base_prompt + "".join(prefix_lines + extra_lines)
            prompt_formatted = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)

            # Over-long candidates are dropped; the loop draws again.
            if prompt_len > input_len:
                continue

            chosen_prompt = prompt_formatted if return_prompt_formatted else prompt
            samples.append(
                SampleRequest(
                    prompt=chosen_prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    # Accepted samples are numbered consecutively from 0.
                    request_id=request_id_prefix + str(len(samples)),
                )
            )
        return samples


# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------


class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset.  Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    def __init__(self, **kwargs) -> None:
        """Forward keyword arguments to the base class, then load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(
        self,
    ) -> None:
        """
        Read the CSV at `self.dataset_path` and keep the usable rows.

        Raises:
            ValueError: If `self.dataset_path` is None.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        df = pd.read_csv(self.dataset_path)
        # Filter to keep only GPT-4 rows.
        gpt4_df = df[df["Model"] == "GPT-4"]
        # Remove failed requests (where Response tokens is 0 or less).
        gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
        # Keep the filtered frame; row sampling happens later in
        # _sample_loaded_data.
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        """
        Draw `num_requests` rows from the loaded frame.

        Rows are drawn with replacement only when more rows are requested
        than are available.

        Returns:
            list: The sampled rows as a list of lists.
        """
        data = self.data.sample(
            n=num_requests,
            random_state=self.random_seed,
            replace=num_requests > len(self.data),
        )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate `num_requests` synthetic requests whose lengths come from
        the sampled rows.

        Args:
            tokenizer (PreTrainedTokenizerBase): Base tokenizer.
            num_requests (int): Number of requests to generate.
            max_loras (Optional[int]): Upper bound for the random LoRA ID.
            lora_path (Optional[str]): LoRA path; LoRA is used only when both
                this and `max_loras` are given.
            request_id_prefix (str): Prefix of each request_id.
            **kwargs: Accepted and ignored.

        Returns:
            list[SampleRequest]: The generated requests.
        """
        samples = []
        data = self._sample_loaded_data(num_requests=num_requests)
        for i, row in enumerate(data):
            # Columns 2 and 3 are read positionally as input and output
            # length -- presumably the request/response token counts;
            # verify against the CSV schema.
            input_len = int(row[2])
            output_len = int(row[3])
            # Keep the caller's tokenizer untouched: rebinding `tokenizer`
            # here would make a LoRA tokenizer from an earlier iteration the
            # "base" fallback for every later one.
            lora_req, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            vocab_size = request_tokenizer.vocab_size
            # Generate a synthetic prompt: a list of token IDs computed as (i +
            # j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            prompt = request_tokenizer.decode(token_ids)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                    request_id=request_id_prefix + str(i),
                )
            )
        return samples


# -----------------------------------------------------------------------------
736
# HuggingFace Dataset Base Implementation
737
738
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Base class for datasets hosted on HuggingFace."""

    # Dataset paths a subclass supports: either a plain set of paths or a
    # mapping from path to a callable. Empty on the base class.
    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        no_stream: bool = False,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Store the dataset coordinates and load the data immediately.

        Args:
            dataset_path (str): Dataset path passed to `load_dataset`.
            dataset_split (str): Split passed to `load_dataset`.
            no_stream (bool): If True, streaming is disabled when loading.
            dataset_subset (Optional[str]): Passed to `load_dataset` as
                `name`.
            **kwargs: Forwarded to BenchmarkDataset.__init__.
        """
        super().__init__(dataset_path=dataset_path, **kwargs)

        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        # Streaming is the default; `no_stream` switches it off.
        self.load_stream = not no_stream
        self.load_data()

    def load_data(self) -> None:
        """Load data from HuggingFace datasets and shuffle with the seed."""
        load_options = {
            "name": self.dataset_subset,
            "split": self.dataset_split,
            "streaming": self.load_stream,
        }
        self.data = load_dataset(self.dataset_path, **load_options)
        self.data = self.data.shuffle(seed=self.random_seed)


# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------


class ConversationDataset(HuggingFaceDataset):
    """Dataset for conversation data with multimodal support."""

    SUPPORTED_DATASET_PATHS = {
        "lmms-lab/LLaVA-OneVision-Data",
        "Aeala/ShareGPT_Vicuna_unfiltered",
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` requests from two-turn conversations.

        The first conversation turn is the prompt and the second one is the
        reference completion.

        Args:
            tokenizer: Tokenizer used to measure prompt/completion lengths.
            num_requests: Number of requests to collect.
            output_len: Fixed expected output length. When None, each
                sample's completion token count is used instead, and samples
                rejected by ``is_valid_sequence`` (or with an empty
                completion) are skipped.
            enable_multimodal_chat: When True, the prompt is passed through
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prepended to a running index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.
        """
        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []
        dynamic_output = output_len is None
        ind = 0

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]

            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)

            if dynamic_output:
                # Skip unusable samples instead of aborting the whole run:
                # an empty completion would otherwise trip the positivity
                # check below before the filter had a chance to drop it.
                if completion_len <= 0 or not is_valid_sequence(
                    prompt_len, completion_len
                ):
                    continue
                expected_output_len = completion_len
            else:
                assert isinstance(output_len, int) and output_len > 0
                expected_output_len = output_len

            mm_content = process_image(item["image"]) if "image" in item else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
            # Advance only for accepted samples so request ids stay contiguous.
            ind += 1
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------


class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.

    Each supported dataset path maps to a function that extracts the prompt
    text from one dataset item; the first entry of ``item["images"]`` is
    attached as multimodal content.
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` image+text requests.

        Args:
            tokenizer: Tokenizer used to measure the prompt length.
            num_requests: Number of requests to collect.
            output_len: Expected output length; ``DEFAULT_OUTPUT_LEN`` when
                None.
            enable_multimodal_chat: When True, the prompt is passed through
                ``apply_multimodal_chat_transformation``.
            request_id_prefix: Prepended to the item index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.

        Raises:
            ValueError: If ``self.dataset_path`` has no registered parser.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN

        # The parser depends only on the dataset path, so resolve and
        # validate it once instead of on every item.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")

        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is the dataset designed for general code editing.  It consists
    of 114,239 instruction-input-output triplets, and covers multiple distinct
    code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` code-editing requests.

        Each prompt is the item's ``input`` followed by its ``instruction``
        and a fixed "code only" suffix, rendered through the tokenizer's
        chat template.

        Args:
            tokenizer: Tokenizer used for the chat template and to measure
                the prompt length.
            num_requests: Number of requests to collect.
            output_len: Expected output length; ``DEFAULT_OUTPUT_LEN`` when
                None.
            enable_multimodal_chat: Accepted for interface compatibility;
                unused.
            request_id_prefix: Prepended to the item index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            # Implicit concatenation rather than a backslash continuation
            # inside the literal: the latter embeds the next source line's
            # indentation into the prompt text.
            prompt = (
                f"{item['input']}\n\n{item['instruction']} Just output "
                "the code, do not include any explanation."
            )

            # apply template
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    We create a single turn dataset for MT-Bench.
    This is similar to Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` single-turn requests.

        Only the first entry of each item's ``turns`` is used; it is
        rendered through the tokenizer's chat template.

        Args:
            tokenizer: Tokenizer used for the chat template and to measure
                the prompt length.
            num_requests: Number of requests to collect.
            output_len: Expected output length; ``DEFAULT_OUTPUT_LEN`` when
                None.
            enable_multimodal_chat: Accepted for interface compatibility;
                unused.
            request_id_prefix: Prepended to the item index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []

        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["turns"][0]

            # apply template
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------


class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing an AIMO dataset with reasoning questions.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime",
        "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` requests from problem/solution pairs.

        Args:
            tokenizer: Tokenizer used to measure problem/solution lengths.
            num_requests: Number of requests to collect.
            output_len: Fixed expected output length. When None, each
                sample's solution token count is used instead, and samples
                rejected by ``is_valid_sequence`` (or with an empty
                solution) are skipped.
            request_id_prefix: Prepended to a running index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.
        """
        sampled_requests = []
        dynamic_output = output_len is None
        ind = 0

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt, completion = item["problem"], item["solution"]

            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)

            if dynamic_output:
                # Skip unusable samples instead of aborting the whole run:
                # an empty solution would otherwise trip the positivity
                # check below before the filter had a chance to drop it.
                if completion_len <= 0 or not is_valid_sequence(
                    prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
                ):
                    continue
                expected_output_len = completion_len
            else:
                assert isinstance(output_len, int) and output_len > 0
                expected_output_len = output_len

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_output_len,
                    multi_modal_data=None,
                    request_id=request_id_prefix + str(ind),
                )
            )
            # Advance only for accepted samples so request ids stay contiguous.
            ind += 1
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------


# Prompt template for the zeta dataset. The two positional placeholders are
# filled with the user's edit events and the user's excerpt, in that order.
zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### User Edits:

{}

### User Excerpt:

{}

### Response:

"""  # noqa: E501


def _format_zeta_prompt(
    sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.

    This function formats examples from the NEP dataset
    into prompts and expected outputs. It could be
    further extended to support more NEP datasets.

    Args:
        sample: The dataset sample containing events,
            inputs, and outputs.
        original_start_marker: The marker indicating the
            start of the editable region. Defaults to
            "<|editable_region_start|>".

    Returns:
        A dictionary with the formatted prompts and expected outputs.
        If the marker does not occur in the sample's output, the whole
        output is used as the expected output.
    """
    events = sample["events"]
    excerpt = sample["input"]
    output = sample["output"]
    prompt = zeta_prompt.format(events, excerpt)

    # following the original implementation, extract the focused region
    # from the raw output
    output_start_index = output.find(original_start_marker)
    if output_start_index == -1:
        # str.find signals "not found" with -1; slicing from -1 would keep
        # only the last character, so fall back to the full output instead.
        output_start_index = 0
    expected_output = output[output_start_index:]

    return {"prompt": prompt, "expected_output": expected_output}


class NextEditPredictionDataset(HuggingFaceDataset):
    """
    Dataset class for processing a Next Edit Prediction dataset.
    """

    SUPPORTED_DATASET_PATHS = {
        "zed-industries/zeta",
    }
    # Per-dataset function turning a raw item into a dict with "prompt" and
    # "expected_output" keys.
    MAPPING_PROMPT_FUNCS = {
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` next-edit-prediction requests.

        The expected output length of each request is the token count of the
        formatted sample's expected output.

        Args:
            tokenizer: Tokenizer used to measure prompt and expected-output
                lengths.
            num_requests: Number of requests to collect.
            request_id_prefix: Prepended to the item index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.

        Raises:
            ValueError: If ``self.dataset_path`` has no formatting function.
        """
        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
        if formatting_prompt_func is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
        samples = []
        for i, raw_sample in enumerate(self.data):
            # Check the budget before doing any work, as the other datasets
            # in this file do, so num_requests=0 yields no sample at all.
            if len(samples) >= num_requests:
                break
            formatted = formatting_prompt_func(raw_sample)
            samples.append(
                SampleRequest(
                    prompt=formatted["prompt"],
                    prompt_len=len(tokenizer(formatted["prompt"]).input_ids),
                    expected_output_len=len(
                        tokenizer(formatted["expected_output"]).input_ids
                    ),
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples


# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing a ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr,  ...        |
    | LibriSpeech    | Audiobook                              | Narrated                 | "LIUM/tedlium"              |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+

    NOTE(review): the hf-subset entry in the LibriSpeech row looks like a
    dataset path rather than a subset name - confirm.
    """  # noqa: E501

    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }

    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    # Whisper max supported duration, in seconds.
    MAX_AUDIO_DURATION_S = 30
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` transcription requests.

        Every request shares the same text prompt
        (``TRANSCRIPTION_PREAMBLE``) and carries one audio clip as
        multimodal data. When ``skip_long_audios`` is set, clips longer than
        ``MAX_AUDIO_DURATION_S`` seconds are skipped and counted; a single
        warning reports the total afterwards.

        Args:
            tokenizer: Tokenizer used to measure the preamble length.
            num_requests: Number of requests to collect.
            output_len: Expected output length; ``DEFAULT_OUTPUT_LEN`` when
                None.
            request_id_prefix: Prepended to a running index to form each
                ``request_id``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The list of sampled ``SampleRequest`` objects, after
            ``maybe_oversample_requests`` has been applied to it.
        """
        # Imported lazily so the module does not require librosa unless an
        # ASR dataset is actually sampled.
        import librosa

        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # Read through ``self`` (like ``skip_long_audios``) so that a
        # subclass overriding the preamble is honoured.
        prompt = self.TRANSCRIPTION_PREAMBLE
        prompt_len = len(tokenizer(prompt).input_ids)
        sampled_requests = []
        skipped = 0
        ind = 0
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            audio = item["audio"]
            y, sr = audio["array"], audio["sampling_rate"]
            duration_s = librosa.get_duration(y=y, sr=sr)
            # Whisper max supported duration
            if self.skip_long_audios and duration_s > self.MAX_AUDIO_DURATION_S:
                skipped += 1
                continue

            mm_content = {"audio": (y, sr)}
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
            # Advance only for accepted clips so request ids stay contiguous.
            ind += 1
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
                " their length being greater than"
                " what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests