"csrc/opt/activation_kernels_opt.cu" did not exist on "8ce9c50d4034de3c557b520935fac1d6dac585a0"
benchmark_dataset.py 38.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# SPDX-License-Identifier: Apache-2.0
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
  - ShareGPT
  - Random (synthetic)
  - Sonnet
  - BurstGPT
  - HuggingFace
  - VisionArena

TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""

import base64
import io
import json
20
import logging
21
22
23
24
25
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
26
27
from io import BytesIO
from typing import Any, Callable, Optional, Union
28
29
30
31
32
33
34
35
36
37

import numpy as np
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
38
from vllm.multimodal.image import convert_image_mode
39
40
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

41
42
logger = logging.getLogger(__name__)

43
44
45
46
47
48
49
50
51
52
53
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------


@dataclass
class SampleRequest:
    """
    Represents a single inference request for benchmarking.

    Attributes:
        prompt: The request prompt. Usually plain text; datasets that apply a
            chat transformation store the resulting list of chat messages here
            instead, which is why the type also admits ``Any``.
        prompt_len: Number of prompt tokens as measured by the dataset.
        expected_output_len: Number of output tokens the benchmark expects for
            this request.
        multi_modal_data: Optional multimodal payload attached to the request.
        lora_request: Optional LoRA adapter to use for this request.
    """

    prompt: Union[str, Any]
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
    lora_request: Optional[LoRARequest] = None


# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------


class BenchmarkDataset(ABC):
    """
    Abstract base class for benchmark datasets.

    Subclasses load their source data (typically in ``load_data``) and turn it
    into ``SampleRequest`` objects in ``sample``.
    """

    DEFAULT_SEED = 0
    # Datasets that attach multimodal data to their requests set this to True.
    IS_MULTIMODAL = False

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
    ) -> None:
        """
        Initialize the BenchmarkDataset with an optional dataset path and random
        seed.

        Args:
            dataset_path (Optional[str]): Path to the dataset. If None, it
                indicates that a default or random dataset might be used.
            random_seed (int): Seed value for reproducible shuffling or
                sampling. Defaults to DEFAULT_SEED; an explicit None is also
                replaced with DEFAULT_SEED.
        """
        self.dataset_path = dataset_path
        # Callers may pass random_seed=None explicitly; fall back to the default.
        self.random_seed = (
            random_seed if random_seed is not None else self.DEFAULT_SEED
        )
        self.data = None

    def apply_multimodal_chat_transformation(
        self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
    ) -> list[dict]:
        """
        Transform a prompt and optional multimodal content into a chat format.

        This method is used for chat models that expect a specific conversation
        format.

        Args:
            prompt (str): The text prompt.
            mm_content (Optional[MultiModalDataDict]): Optional multimodal
                content item appended after the text part.

        Returns:
            list[dict]: A single-message conversation with role "user" whose
            content lists the text part followed by ``mm_content`` (if given).
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            content.append(mm_content)
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:
        """
        Load data from the dataset path into self.data.

        This method must be overridden by subclasses since the method to load
        data will vary depending on the dataset format and source.

        Raises:
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError("load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
    ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
        """
        Optionally select a random LoRA request and return its associated
        tokenizer.

        This method is used when LoRA parameters are provided. It randomly
        selects a LoRA ID in ``[1, max_loras]`` and retrieves a cached
        tokenizer for that LoRA, loading it on first use. If no LoRA tokenizer
        is available, the base tokenizer is returned.

        Args:
            tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
                LoRA is selected.
            max_loras (Optional[int]): The maximum number of LoRAs available.
                If None, LoRA is not used.
            lora_path (Optional[str]): Path to the LoRA parameters on disk. If
                None, LoRA is not used.

        Returns:
            tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
            element is a LoRARequest (or None if not applicable) and the second
            element is the tokenizer associated with the LoRA request (or the
            base tokenizer).
        """
        if max_loras is None or lora_path is None:
            return None, tokenizer

        # Generate a random LoRA ID in the range [1, max_loras].
        lora_id = random.randint(1, max_loras)
        lora_request = LoRARequest(
            lora_name=str(lora_id),
            lora_int_id=lora_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        # Load each LoRA's tokenizer only once; the module-level cache is
        # reused across calls.
        if lora_id not in lora_tokenizer_cache:
            lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
        # A falsy cached entry means there is no LoRA-specific tokenizer, so
        # fall back to the base tokenizer.
        return lora_request, lora_tokenizer_cache[lora_id] or tokenizer

    @abstractmethod
    def sample(
        self, tokenizer: PreTrainedTokenizerBase, num_requests: int
    ) -> list[SampleRequest]:
        """
        Abstract method to generate sample requests from the dataset.

        Subclasses must override this method to implement dataset-specific logic
        for generating a list of SampleRequest objects.

        Args:
            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.

        Returns:
            list[SampleRequest]: A list of sample requests generated from the
            dataset.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self, requests: list[SampleRequest], num_requests: int
    ) -> None:
        """
        Oversample the list of requests in place if it holds fewer than the
        desired number.

        Extra requests are drawn with replacement from ``requests`` after
        reseeding the global ``random`` module with ``self.random_seed``, so
        the padding is reproducible.

        Args:
            requests (list[SampleRequest]): The current list of sampled
                requests; extended in place.
            num_requests (int): The target number of requests.

        Raises:
            ValueError: If ``requests`` is empty while more requests are
                needed, since there is nothing to oversample from.
        """
        shortfall = num_requests - len(requests)
        if shortfall <= 0:
            return
        if not requests:
            # random.choices([]) would raise an opaque IndexError here.
            raise ValueError(
                f"Cannot oversample to {num_requests} requests: no valid "
                "requests were sampled from the dataset."
            )
        random.seed(self.random_seed)
        requests.extend(random.choices(requests, k=shortfall))
        logger.info("Oversampled requests to reach %d total samples.", num_requests)

# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------


def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Validate a sequence based on prompt and output lengths.

    Default pruning criteria are copied from the original `sample_hf_requests`
    and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
    from `sample_requests` in benchmark_throughput.py.

    Args:
        prompt_len: Number of prompt tokens.
        output_len: Number of output tokens.
        min_len: Minimum allowed length for the prompt and, unless skipped,
            for the output.
        max_prompt_len: Maximum allowed prompt length.
        max_total_len: Maximum allowed prompt + output length.
        skip_min_output_len_check: If True, ``output_len`` is not checked
            against ``min_len`` (useful when the caller fixes the output
            length explicitly).

    Returns:
        True if none of the invalid conditions are met, False otherwise.
    """
    prompt_too_short = prompt_len < min_len
    output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
    prompt_too_long = prompt_len > max_prompt_len
    combined_too_long = (prompt_len + output_len) > max_total_len

    return not (
        prompt_too_short or output_too_short or prompt_too_long or combined_too_long
    )


@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Resolve ``lora_path`` to an absolute adapter path, memoized per input."""
    resolved_path = get_adapter_absolute_path(lora_path)
    return resolved_path


# Module-level cache mapping a LoRA integer ID to the tokenizer loaded for it;
# filled lazily by BenchmarkDataset.get_random_lora_request.
lora_tokenizer_cache: dict[int, AnyTokenizer] = dict()


def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports three input types:

    1. Dictionary with raw image bytes: expects a dict with a 'bytes' key
       containing raw image data, which is loaded as a PIL.Image.Image and then
       handled as in case 2.

    2. PIL.Image.Image input: converts the image to RGB, saves it as a JPEG in
       memory, encodes the JPEG data as a base64 string and returns it as a
       base64 data URL.

    3. String input: treated as a URL or local file path. Strings that already
       start with "http://", "https://" or "file://" are used as-is; any other
       string gets "file://" prepended.

    Returns:
        A dict of the form ``{"type": "image_url", "image_url": {"url": ...}}``.

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(image, str):
        # Remote (http/https) and explicit file URLs are passed through;
        # anything else is treated as a local path.
        image_url = (
            image
            if image.startswith(("http://", "https://", "file://"))
            else f"file://{image}"
        )
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(io.BytesIO(image["bytes"]))

    if isinstance(image, Image.Image):
        image = convert_image_mode(image, "RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
        }

    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image"
        " or str or dictionary with raw image bytes."
    )


# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------


class RandomDataset(BenchmarkDataset):
    """
    Synthetic dataset of random token-ID prompts.

    Each request is an optional random prefix shared by all requests followed
    by a per-request run of consecutive token IDs starting at a random offset.
    Input and output lengths are drawn uniformly around the requested values.
    """

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # `sample` draws everything from numpy's global RNG; seed it so the
        # same random_seed reproduces the same synthetic requests.
        np.random.seed(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate ``num_requests`` synthetic requests.

        Args:
            tokenizer: Tokenizer used to turn random token IDs into text.
            num_requests: Number of requests to generate.
            prefix_len: Number of random prefix tokens shared by all requests.
            range_ratio: Relative half-width ``b`` of the sampling range; input
                and output lengths are drawn uniformly from
                ``[X * (1 - b), X * (1 + b)]``. Must satisfy ``0 <= b < 1``.
            input_len: Target number of input tokens per request (excluding
                the prefix); the tokenizer's special-token count is subtracted.
            output_len: Target number of output tokens per request.

        Returns:
            A list of ``num_requests`` SampleRequest objects.

        Raises:
            ValueError: If ``range_ratio`` is outside ``[0, 1)``.
        """
        # Validate explicitly (asserts vanish under -O); a negative ratio
        # would make the low bound exceed the high bound.
        if not 0.0 <= range_ratio < 1.0:
            raise ValueError(
                "random_range_ratio must be in [0, 1) to ensure a valid "
                "sampling range"
            )

        vocab_size = tokenizer.vocab_size
        # Leave room for the special tokens the tokenizer would add.
        num_special_tokens = tokenizer.num_special_tokens_to_add()
        real_input_len = input_len - num_special_tokens

        prefix_token_ids = (
            np.random.randint(0, vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

        # Sampling range: [X * (1 - b), X * (1 + b)]
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        output_high = int(output_len * (1 + range_ratio))

        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low, output_high)

        input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
        output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
        offsets = np.random.randint(0, vocab_size, size=num_requests)

        requests = []
        for i in range(num_requests):
            # Consecutive token IDs starting at a random offset, wrapped into
            # the vocabulary.
            inner_seq = (
                (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
            ).tolist()
            token_sequence = prefix_token_ids + inner_seq
            prompt = tokenizer.decode(token_sequence)
            # After decoding the prompt we have to encode and decode it again.
            # This is done because in some cases N consecutive tokens
            # give a string tokenized into != N number of tokens.
            # For example for GPT2Tokenizer:
            # [6880, 6881] -> ['Ġcalls', 'here'] ->
            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
            # To avoid uncontrolled change of the prompt length,
            # the encoded sequence is truncated before being decode again.
            # The decoded prompt contains the prefix too, so truncate to the
            # full prefix + input length, not the input length alone.
            target_len = len(prefix_token_ids) + int(input_lens[i])
            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
                :target_len
            ]
            prompt = tokenizer.decode(re_encoded_sequence)
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    # Report what was actually kept; re-encoding can merge
                    # tokens and yield fewer than target_len.
                    prompt_len=len(re_encoded_sequence),
                    expected_output_len=int(output_lens[i]),
                )
            )
        return requests


# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------


class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset.  Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Load the JSON file at ``dataset_path``, keep entries with at least two
        conversation turns, and shuffle them using ``random_seed``.

        Raises:
            ValueError: If ``dataset_path`` is not set.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Filter entries with at least two conversation turns.
        self.data = [
            entry
            for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build requests from the first two turns of each conversation.

        Args:
            tokenizer: Base tokenizer used to measure prompt/completion lengths.
            num_requests: Number of requests to return (oversampled if needed).
            lora_path: Optional LoRA path; enables random LoRA selection
                together with ``max_loras``.
            max_loras: Optional number of LoRAs to choose from.
            output_len: Fixed expected output length. If None, the length of
                the reference completion is used.
            enable_multimodal_chat: If True, wrap the prompt in chat format.

        Returns:
            A list of SampleRequest objects.
        """
        samples: list = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            prompt, completion = (
                entry["conversations"][0]["value"],
                entry["conversations"][1]["value"],
            )

            # Use a separate name so `tokenizer` stays the base tokenizer:
            # overwriting it would make later fallbacks return a previous
            # LoRA's tokenizer instead of the base one.
            lora_request, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            prompt_len = len(request_tokenizer(prompt).input_ids)
            if output_len is None:
                # Derive the expected output length from the reference reply.
                new_output_len = len(request_tokenizer(completion).input_ids)
            else:
                new_output_len = output_len
            if not is_valid_sequence(
                prompt_len,
                new_output_len,
                skip_min_output_len_check=output_len is not None,
            ):
                continue
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, None)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                )
            )
        self.maybe_oversample_requests(samples, num_requests)
        return samples


# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------


class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset.  Loads poem lines from a
    text file and generates sample requests.  Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the text file at ``dataset_path`` into a list of lines.

        Raises:
            ValueError: If ``dataset_path`` is not set.
        """
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = f.readlines()

    def sample(
        self,
        tokenizer,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        **kwargs,
    ) -> list:
        """
        Build prompts from poem lines.

        Each prompt is a fixed instruction, then a fixed prefix made of the
        first lines of the file, then randomly chosen lines. Candidates whose
        chat-formatted length exceeds ``input_len`` are discarded and redrawn.

        Args:
            tokenizer: Tokenizer with a chat template, used to measure lengths.
            num_requests: Number of requests to generate.
            prefix_len: Approximate token budget for the base prompt plus the
                shared prefix lines.
            input_len: Maximum token length of the chat-formatted prompt.
            output_len: Expected output length for every request.
            return_prompt_formatted: If True, return the chat-formatted prompt
                instead of the raw prompt text.

        Returns:
            A list of ``num_requests`` SampleRequest objects.

        Raises:
            ValueError: If the dataset has no lines, or if ``input_len`` does
                not exceed the length of the chat-formatted base prompt.
        """
        if not self.data:
            # Guard the average below against division by zero.
            raise ValueError("Sonnet dataset is empty; cannot sample requests.")

        # Calculate average token length for a poem line.
        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
        avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)

        # Build the base prompt.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_msg = [{"role": "user", "content": base_prompt}]
        base_fmt = tokenizer.apply_chat_template(
            base_msg, add_generation_prompt=True, tokenize=False
        )
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset})."
            )

        # Determine how many poem lines to use.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]

        samples = []
        # NOTE(review): this loop never terminates if every candidate prompt
        # exceeds input_len (e.g. a very long prefix); consider bounding it.
        while len(samples) < num_requests:
            extra_lines = random.choices(
                self.data, k=num_input_lines - num_prefix_lines
            )
            prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
            msg = [{"role": "user", "content": prompt}]
            prompt_formatted = tokenizer.apply_chat_template(
                msg, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)
            if prompt_len <= input_len:
                samples.append(
                    SampleRequest(
                        prompt=prompt_formatted if return_prompt_formatted else prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                    )
                )
        return samples


# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------


class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset.  Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(
        self,
    ):
        """
        Load the CSV at ``dataset_path`` and keep successful GPT-4 rows.

        Raises:
            ValueError: If ``dataset_path`` is not set.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        df = pd.read_csv(self.dataset_path)
        # Filter to keep only GPT-4 rows.
        gpt4_df = df[df["Model"] == "GPT-4"]
        # Remove failed requests (where Response tokens is 0 or less).
        gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
        # Keep every remaining row; per-call sampling happens in
        # _sample_loaded_data.
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        """
        Draw ``num_requests`` rows (seeded by ``random_seed``) as lists of
        values, sampling with replacement only when there are too few rows.
        """
        if num_requests <= len(self.data):
            data = self.data.sample(n=num_requests, random_state=self.random_seed)
        else:
            data = self.data.sample(
                n=num_requests,
                random_state=self.random_seed,
                replace=True,
            )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate requests whose lengths come from sampled BurstGPT rows.

        Prompts are synthetic token sequences of the recorded request length;
        the expected output length is the recorded response length.

        Args:
            tokenizer: Base tokenizer used to decode the synthetic prompts.
            num_requests: Number of requests to generate.
            max_loras: Optional number of LoRAs to choose from.
            lora_path: Optional LoRA path; enables random LoRA selection
                together with ``max_loras``.

        Returns:
            A list of ``num_requests`` SampleRequest objects.
        """
        samples = []
        data = self._sample_loaded_data(num_requests=num_requests)
        for i, row in enumerate(data):
            # NOTE(review): columns are read by position; indices 2 and 3 are
            # assumed to be "Request tokens" and "Response tokens" - confirm
            # against the CSV header.
            input_len = int(row[2])
            output_len = int(row[3])
            # Keep `tokenizer` as the base tokenizer across iterations so LoRA
            # fallbacks do not pick up a previous LoRA's tokenizer.
            lora_req, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            vocab_size = request_tokenizer.vocab_size
            # Generate a synthetic prompt: a list of token IDs computed as (i +
            # j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            prompt = request_tokenizer.decode(token_ids)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                )
            )
        return samples


# -----------------------------------------------------------------------------
603
# HuggingFace Dataset Base Implementation
604
605
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Base class for datasets hosted on HuggingFace."""

    # Dataset paths supported by a subclass: either a set of paths or a dict
    # mapping each path to a callable.
    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            dataset_path: HuggingFace dataset identifier passed to
                ``load_dataset``.
            dataset_split: Split to load (passed as ``split``).
            dataset_subset: Optional dataset configuration (passed as ``name``).
            **kwargs: Forwarded to ``BenchmarkDataset.__init__``.
        """
        super().__init__(dataset_path=dataset_path, **kwargs)
        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        self.load_data()

    def load_data(self) -> None:
        """Load data from HuggingFace datasets in streaming mode and shuffle it."""
        self.data = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=True,
        )
        self.data = self.data.shuffle(seed=self.random_seed)


# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------


class ConversationDataset(HuggingFaceDataset):
    """Dataset for conversation data with multimodal support."""

    SUPPORTED_DATASET_PATHS = {
        "lmms-lab/LLaVA-OneVision-Data",
        "Aeala/ShareGPT_Vicuna_unfiltered",
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build requests from the first two turns of each conversation.

        Args:
            tokenizer: Tokenizer used to measure prompt/completion lengths.
            num_requests: Number of requests to return (oversampled if needed).
            output_len: Fixed expected output length. If None, the length of
                the reference completion is used and sequences failing
                ``is_valid_sequence`` are skipped.
            enable_multimodal_chat: If True, wrap prompt and image in chat
                format.

        Returns:
            A list of SampleRequest objects.

        Raises:
            ValueError: If ``output_len`` is given but is not a positive int.
        """
        dynamic_output = output_len is None
        if not dynamic_output and (
            not isinstance(output_len, int) or output_len <= 0
        ):
            raise ValueError(
                f"output_len must be a positive integer, got {output_len!r}."
            )

        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []
        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]

            prompt_len = len(tokenizer(prompt).input_ids)
            if dynamic_output:
                completion_len = len(tokenizer(completion).input_ids)
                # Filter before using the completion length so that empty or
                # too-short completions are skipped instead of aborting.
                if not is_valid_sequence(prompt_len, completion_len):
                    continue
                request_output_len = completion_len
            else:
                request_output_len = output_len

            # Rows without an image (missing key or null value) are text-only.
            image = item.get("image")
            mm_content = process_image(image) if image is not None else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=request_output_len,
                    multi_modal_data=mm_content,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------


698
class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.
    """

    DEFAULT_OUTPUT_LEN = 128
    # Maps each supported dataset path to a function extracting the prompt
    # text from one row.
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build requests from each row's prompt text and first image.

        Args:
            tokenizer: Tokenizer used to measure the prompt length.
            num_requests: Number of requests to return (oversampled if needed).
            output_len: Expected output length; defaults to DEFAULT_OUTPUT_LEN.
            enable_multimodal_chat: If True, wrap prompt and image in chat
                format.

        Returns:
            A list of SampleRequest objects.

        Raises:
            ValueError: If ``dataset_path`` has no registered prompt parser.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The parser depends only on the dataset path, so resolve (and
        # validate) it once instead of on every row.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")

        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            # Only the first image of each row is used.
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is the dataset designed for general code editing.  It consists
    of 114,239 instruction-input-output triplets, and covers multiple distinct
    code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build requests of the form "<instruction>:\\n<input>".

        Args:
            tokenizer: Tokenizer used to measure the prompt length.
            num_requests: Number of requests to return (oversampled if needed).
            output_len: Expected output length; defaults to DEFAULT_OUTPUT_LEN.
            enable_multimodal_chat: Accepted for interface compatibility; not
                used by this dataset.

        Returns:
            A list of SampleRequest objects.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = f"{item['instruction']}:\n{item['input']}"
            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


792
793
794
795
796
797
798
799
800
801
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    We create a single turn dataset for MT-Bench: only the first entry of each
    row's ``turns`` is used, wrapped as a user message in the tokenizer's chat
    template.

    This is similar to Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` single-turn chat requests.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template`` and used
                to count prompt tokens of the templated text.
            num_requests: Target number of requests; the list is handed to
                ``maybe_oversample_requests`` (base class) before returning.
            output_len: Expected output length for every request; falls back
                to ``DEFAULT_OUTPUT_LEN`` when omitted.
            enable_multimodal_chat: Accepted for interface compatibility;
                not used by this dataset.

        Returns:
            A list of ``SampleRequest`` objects whose prompt is the templated
            string (not token ids).
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        requests = []
        for row in self.data:
            if len(requests) >= num_requests:
                break
            conversation = [{"role": "user", "content": row["turns"][0]}]
            # Render to text with the assistant header appended, so the
            # prompt is ready for generation.
            templated = tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
            n_prompt_tokens = len(tokenizer(templated).input_ids)
            requests.append(
                SampleRequest(
                    expected_output_len=output_len,
                    prompt_len=n_prompt_tokens,
                    prompt=templated,
                )
            )

        self.maybe_oversample_requests(requests, num_requests)
        return requests


847
848
849
850
851
852
853
854
855
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------


class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing an AIMO dataset with reasoning questions.

    Each row's ``problem`` becomes the prompt. When no ``output_len`` is
    given, the token length of the row's ``solution`` is used as the
    expected output length, and rows rejected by ``is_valid_sequence``
    (using the limits below) are skipped.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime",
        "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT",
    }

    # Length limits passed to is_valid_sequence when the expected output
    # length is derived from the reference solution.
    MAX_PROMPT_LEN = 2048
    MAX_TOTAL_LEN = 32000

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` reasoning requests.

        Args:
            tokenizer: Tokenizer used to count prompt (and, in dynamic mode,
                solution) tokens.
            num_requests: Target number of requests; the list is handed to
                ``maybe_oversample_requests`` (base class) before returning.
            output_len: Fixed expected output length. When ``None``, each
                request uses the token length of its reference solution.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            ValueError: If ``output_len`` is given but is not a positive int.
        """
        dynamic_output = output_len is None
        # Validate a user-supplied length once, up front, instead of with an
        # assert per item (asserts are stripped under ``python -O``).
        if not dynamic_output and not (
            isinstance(output_len, int) and output_len > 0
        ):
            raise ValueError(
                f"output_len must be a positive integer, got {output_len!r}"
            )

        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["problem"]
            prompt_len = len(tokenizer(prompt).input_ids)

            if dynamic_output:
                # Only tokenize the reference solution when its length is
                # actually needed.
                expected_len = len(tokenizer(item["solution"]).input_ids)
                # Filter before use, so an empty or out-of-range solution is
                # skipped rather than aborting the whole sampling run.
                if expected_len <= 0 or not is_valid_sequence(
                    prompt_len,
                    expected_len,
                    max_prompt_len=self.MAX_PROMPT_LEN,
                    max_total_len=self.MAX_TOTAL_LEN,
                ):
                    continue
            else:
                expected_len = output_len

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_len,
                    multi_modal_data=None,
                )
            )

        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
898
899


900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------


zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### User Edits:

{}

### User Excerpt:

{}

### Response:

918
"""  # noqa: E501
919
920
921


def _format_zeta_prompt(
922
923
    sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
924
    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
925
926
927

    This function formats examples from the NEP dataset
    into prompts and expected outputs. It could be
928
    further extended to support more NEP datasets.
929

930
    Args:
931
        sample: The dataset sample containing events,
932
            inputs, and outputs.
933
934
        original_start_marker: The marker indicating the
            start of the editable region. Defaults to
935
            "<|editable_region_start|>".
936

937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
    Returns:
        A dictionary with the formatted prompts and expected outputs.
    """
    events = sample["events"]
    input = sample["input"]
    output = sample["output"]
    prompt = zeta_prompt.format(events, input)

    # following the original implementation, extract the focused region
    # from the raw output
    output_start_index = output.find(original_start_marker)
    output_focused_region = output[output_start_index:]
    expected_output = output_focused_region

    return {"prompt": prompt, "expected_output": expected_output}


class NextEditPredictionDataset(HuggingFaceDataset):
    """
    Dataset class for processing a Next Edit Prediction dataset.

    Each supported dataset path maps to a formatting function that turns a
    raw row into ``{"prompt": ..., "expected_output": ...}``. The token count
    of ``expected_output`` becomes the request's expected output length.
    """

    SUPPORTED_DATASET_PATHS = {
        "zed-industries/zeta",
    }
    MAPPING_PROMPT_FUNCS = {
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
        """Build up to ``num_requests`` next-edit-prediction requests.

        Args:
            tokenizer: Tokenizer used to count prompt and expected-output
                tokens.
            num_requests: Target number of requests; the list is handed to
                ``maybe_oversample_requests`` (base class) before returning.

        Returns:
            A list of ``SampleRequest`` objects.

        Raises:
            ValueError: If ``self.dataset_path`` has no formatting function.
        """
        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
        if formatting_prompt_func is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")

        samples = []
        for raw_sample in self.data:
            # Check before appending (like the sibling datasets) so that a
            # non-positive num_requests yields no samples instead of one.
            if len(samples) >= num_requests:
                break
            formatted = formatting_prompt_func(raw_sample)
            samples.append(
                SampleRequest(
                    prompt=formatted["prompt"],
                    prompt_len=len(tokenizer(formatted["prompt"]).input_ids),
                    expected_output_len=len(
                        tokenizer(formatted["expected_output"]).input_ids
                    ),
                )
            )

        self.maybe_oversample_requests(samples, num_requests)
        return samples


988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing an ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr,  ...        |
    | LibriSpeech    | Audiobook                              | Narrated                 | clean, other                |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+

    Every request uses the same transcription preamble as its text prompt and
    carries the clip as ``{"audio": (array, sampling_rate)}``. Clips longer
    than ``MAX_AUDIO_DURATION_S`` seconds are skipped when
    ``skip_long_audios`` is set.
    """  # noqa: E501

    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }

    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    # Whisper max supported duration, in seconds.
    MAX_AUDIO_DURATION_S = 30
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` transcription requests.

        Args:
            tokenizer: Tokenizer used to count the prompt tokens.
            num_requests: Target number of requests; the list is handed to
                ``maybe_oversample_requests`` (base class) before returning.
            output_len: Expected output length for every request; falls back
                to ``DEFAULT_OUTPUT_LEN`` when omitted.

        Returns:
            A list of ``SampleRequest`` objects with audio attached as
            ``multi_modal_data``.
        """
        import librosa

        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # All requests share the same text prompt, so tokenize it once.
        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
        prompt_len = len(tokenizer(prompt).input_ids)
        sampled_requests = []
        skipped = 0
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            audio = item["audio"]
            y, sr = audio["array"], audio["sampling_rate"]
            # Only measure the clip when long clips are being filtered out.
            if (
                self.skip_long_audios
                and librosa.get_duration(y=y, sr=sr) > self.MAX_AUDIO_DURATION_S
            ):
                skipped += 1
                continue

            mm_content = {"audio": (y, sr)}
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                )
            )
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
                " their length being greater than"
                " what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests