Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
......@@ -18,7 +18,7 @@ class BeamSearchSequence:
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: Optional[LoRARequest] = None
......
......@@ -11,17 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
import ast
import base64
import io
import json
import logging
import math
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from collections.abc import Iterator, Mapping
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast
import numpy as np
from PIL import Image
......@@ -69,13 +73,14 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
prompt: Union[str, Any]
prompt: Union[str, list[str]]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None
lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# -----------------------------------------------------------------------------
......@@ -112,7 +117,9 @@ class BenchmarkDataset(ABC):
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
mm_content: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
......@@ -120,7 +127,15 @@ class BenchmarkDataset(ABC):
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
if isinstance(mm_content, list):
content.extend(cast(list[dict[str, Any]], mm_content))
elif isinstance(mm_content, dict):
content.append(mm_content)
else:
raise TypeError(
"Could not process multimodal content of type: " +
f"{type(mm_content)}"
)
return [{"role": "user", "content": content}]
def load_data(self) -> None:
......@@ -183,7 +198,8 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]:
num_requests: int,
request_id_prefix: str = "") -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
......@@ -194,6 +210,8 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
......@@ -201,8 +219,12 @@ class BenchmarkDataset(ABC):
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
def maybe_oversample_requests(
self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
......@@ -211,11 +233,17 @@ class BenchmarkDataset(ABC):
requests (List[SampleRequest]): The current list of sampled
requests.
num_requests (int): The target number of requests.
request_id_prefix (str) The prefix of the request ids.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
additional = deepcopy(
random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
......@@ -266,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
"""
Process a single image input and return a multimedia content dictionary.
Supports three input types:
Supports the following input types:
1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
containing raw image data. - Loads the bytes as a PIL.Image.Image.
......@@ -306,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]:
" or str or dictionary with raw image bytes.")
def process_video(video: Any) -> Mapping[str, Any]:
"""
Process a single video input and return a multimedia content dictionary.
Supports the following input types:
1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
containing raw video data.
2. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(video, dict) and 'bytes' in video:
video_bytes = video['bytes']
video_base64 = base64.b64encode(video_bytes).decode("utf-8")
return {
"type": "video_url",
"video_url": {
"url": f"data:video/mp4;base64,{video_base64}"
},
}
if isinstance(video, str):
video_url = (video if video.startswith(
("http://", "file://")) else f"file://{video}")
return {"type": "video_url", "video_url": {"url": video_url}}
raise ValueError(
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
)
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Strategy:
- Sample input/output token lengths per request from integer-uniform ranges
around configured means (controlled by range_ratio).
- Prepend a fixed random prefix of length prefix_len.
- Generate the remaining tokens as a reproducible sequence:
(offset + index + arange(input_len)) % vocab_size.
- Decode then re-encode/truncate to ensure prompt token counts match.
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
"""
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128
def __init__(
self,
**kwargs,
) -> None:
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
random.seed(self.random_seed)
np.random.seed(self.random_seed)
# Use numpy's default_rng for deterministic sampling
# Do not use random.seed() or np.random.seed() elsewhere in this class.
# This ensures that the RNG is isolated from global RNG state.
self._rng = np.random.default_rng(self.random_seed)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
assert range_ratio < 1.0, (
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
requests = []
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
)
)
# only used for embeddings benchmark.
if batchsize > 1:
batch_requests = []
# Create batched requests
for i in range(0, num_requests, batchsize):
batch = requests[i : i + batchsize]
batch_requests.append(
SampleRequest(
prompt=[req.prompt for req in batch],
prompt_len=sum(req.prompt_len for req in batch),
expected_output_len=0,
request_id=request_id_prefix + str(i // batchsize),
)
)
requests = batch_requests
return requests
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
def get_prefix(
self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
self._rng.integers(
0, tokenizer.vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
def get_sampling_params(
self,
num_requests: int,
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if not (0.0 <= range_ratio < 1.0):
raise ValueError("range_ratio must be in [0, 1).")
num_special_tokens = int(tokenizer.num_special_tokens_to_add())
real_input_len = max(0, int(input_len) - num_special_tokens)
# Bounds use floor for low and ceil for high
input_low = math.floor(real_input_len * (1 - range_ratio))
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
if input_low > input_high:
raise ValueError(
"Invalid input sampling interval: "
f"low={input_low} > high={input_high}"
)
if output_low > output_high:
raise ValueError(
"Invalid output sampling interval: "
f"low={output_low} > high={output_high}"
)
# Add logging for debugging
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low, input_high, output_low, output_high)
input_low,
input_high,
output_low,
output_high,
)
input_lens = np.random.randint(input_low,
input_high + 1,
input_lens = self._rng.integers(input_low, input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
output_lens = self._rng.integers(output_low, output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size,
size=num_requests)
return input_lens, output_lens, offsets
requests = []
for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
vocab_size).tolist()
def generate_token_sequence(
self,
*,
tokenizer: PreTrainedTokenizerBase,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
input_len: int,
offset: int,
index: int,
) -> tuple[str, int]:
"""
Returns (prompt, total_input_len).
NOTE: After decoding the prompt we have to encode and decode it again.
This is done because in some cases N consecutive tokens
give a string tokenized into != N number of tokens.
For example for GPT2Tokenizer:
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
To avoid uncontrolled change of the prompt length,
the encoded sequence is truncated before being decode again.
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len))
% vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
# Decode, then re-encode and truncate to preserve token count invariants
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
total_input_len = prefix_len + int(input_lens[i])
total_input_len = prefix_len + int(input_len)
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
requests.append(
SampleRequest(
return prompt, total_input_len
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
class RandomMultiModalDataset(RandomDataset):
"""
Synthetic multimodal dataset (text + images) that extends RandomDataset.
Status:
- Images: supported via synthetic RGB data.
- Video: not yet supported (TODO: implement video generation method).
- Audio: not yet supported.
Sampling overview:
1) Number of items per request is sampled uniformly from the integer range
[floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
remaining probabilities are renormalized.
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
OBS.: Only image sampling is supported for now.
"""
IS_MULTIMODAL = True
# NOTE: video sampling is WIP. Setting it to 0.
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
DEFAULT_BASE_ITEMS_PER_REQUEST = 1
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
DEFAULT_MM_ITEM_BUCKET_CONFIG = {
(256, 256, 1): 0.5,
(720, 1280, 1): 0.5,
(720, 1280, 16): 0.0,
}
DEFAULT_ENABLE_MULTIMODAL_CHAT = False
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress.
"""
random_pixels = self._rng.integers(
0,
256,
(height, width, 3),
dtype=np.uint8,
)
return Image.fromarray(random_pixels)
def generate_synthetic_video(self, width: int,
height: int,
num_frames: int) -> Any:
"""Generate synthetic video with random values.
TODO: Finish this method.
"""
raise NotImplementedError("Video sampling is WIP.")
def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
"""Map the configuration to the modality."""
if config[-1] == 1:
return "image"
elif config[-1] > 1:
return "video"
else:
raise ValueError(f"Invalid multimodal item configuration: {config}")
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
float]) -> dict[tuple[int, int, int], float]:
"""
Remove zero probability entries
and normalize the bucket config to sum to 1.
"""
# Raise error if value is negative
if any(v < 0 for v in bucket_config.values()):
raise ValueError("Bucket config values must be non-negative.")
# Remove zero probability entries
bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
# if bucket config is empty, raise error
if not bucket_config:
raise ValueError("Got invalid bucket config. "
"Bucket config values must be non-zero.")
# Normalize the remaining bucket config to sum to 1
total = sum(bucket_config.values())
return {k: v / total for k, v in bucket_config.items()}
def generate_mm_item(self,
mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]:
"""
Create synthetic images and videos and
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
"""
if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image(
mm_item_config[1],
mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video(
mm_item_config[1],
mm_item_config[0],
mm_item_config[2]))
else:
raise ValueError(f"Invalid multimodal item configuration: "
f"{mm_item_config}")
def get_mm_item_sampling_params(
self,
base_items_per_request: int,
num_mm_items_range_ratio: float,
limit_mm_per_prompt: dict[str, int],
bucket_config: dict[tuple[int, int, int], float],
) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
"""
Get the sampling parameters for the multimodal items.
"""
# Enforce num_mm_items_range_ratio <= 1
if not (0.0 <= num_mm_items_range_ratio <= 1.0):
raise ValueError("num_mm_items_range_ratio must be in [0, 1].")
# Ensure modalities to sample are in limit_mm_per_prompt
for k, v in bucket_config.items():
# get modality from bucket config
modality = self.map_config_to_modality(k)
if modality not in limit_mm_per_prompt:
raise ValueError(f"Modality {modality} is not in "
f"limit_mm_per_prompt: "
f"{limit_mm_per_prompt.keys()}")
# Remove zero probability entries
# and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config)
logger.info(
"Normalized bucket config: %s", bucket_config,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config}
limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities}
if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in "
"bucket_config.")
logger.info(
"Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
)
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min(
sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
)
# Ensure min num mm items is at least 0
min_num_mm_items = max(
0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
)
# Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}")
logger.info(
"Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items,
)
return (
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
)
def get_mm_item_iterator(
self,
min_num_mm_items: int,
max_num_mm_items: int,
bucket_config: dict[tuple[int, int, int], float],
limit_mm_per_prompt: dict[str, int],
) -> Iterator[tuple[int,int, int]]:
"""
Iterator over the multimodal items for each request
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
Note:
- This function operates on a per-request shallow copy of
`bucket_config` (tuple->float). The original dict passed to
`sample` is not mutated. If this ever changes, a test
is implemented and will fail.
"""
# Get the number of multimodal items to sample
request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
)
# If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0:
return
# Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config}
# Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt
if modality_counter[modality] < limit_mm_per_prompt[modality]:
modality_counter[modality] += 1
yield (
mm_item_config
)
else:
# If the counter is greater than the limit per prompt
# set all multimodal items of this modality to 0
for k, v in bucket_config_copy.items():
if self.map_config_to_modality(k) == modality:
bucket_config_copy[k] = 0
# If all configs are 0, break the loop
# This should not happen as request_num_mm_items is at most
# the sum of limit_mm_per_prompt for all modalities
if all(v == 0 for v in bucket_config_copy.values()):
logger.warning("Exhausted all multimodal items "
"of modality %s",
modality)
break
# Renormalize the bucket config
bucket_config_copy = self.normalize_bucket_config(
bucket_config_copy)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs,
) -> list[SampleRequest]:
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.")
# Get the sampling parameters for the dataset
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
(
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
) = self.get_mm_item_sampling_params(
base_items_per_request,
num_mm_items_range_ratio,
limit_mm_per_prompt,
bucket_config,
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
# Add synthetic multimodal items to each request
mm_requests = []
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
# Get multimodal item iterator for a given request
mm_item_iterator = self.get_mm_item_iterator(
min_num_mm_items,
max_num_mm_items,
bucket_config,
limit_mm_per_prompt,
)
mm_content = cast(list[dict[str, Any]], [
self.generate_mm_item(mm_item_config)
for mm_item_config in mm_item_iterator
])
if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
sample_request = SampleRequest(
prompt=mm_chat_prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
multi_modal_data=None,
request_id=request_id_prefix + str(i),
)
else:
sample_request = SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
))
return requests
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)
mm_requests.append(sample_request)
return mm_requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
......@@ -432,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
samples: list = []
ind = 0
for entry in self.data:
if len(samples) >= num_requests:
break
......@@ -455,9 +983,10 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len
is not None):
continue
# TODO: Also support ShareGPT4Video.
if image_path := entry.get("image"):
mm_content = process_image(image_path)
elif video_path := entry.get("video"):
mm_content = process_video(video_path)
else:
mm_content = None
if enable_multimodal_chat:
......@@ -470,8 +999,10 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len,
lora_request=lora_request,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(samples, num_requests)
ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
......@@ -488,8 +1019,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
type=str,
default="random",
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
"prefix_repetition"
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
......@@ -589,6 +1120,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
)
random_group.add_argument(
"--random-batch-size",
type=int,
default=1,
help=("Batch size for random sampling. "
"Only used for embeddings benchmark."),
)
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
"random multimodal dataset options extended from random dataset")
random_mm_group.add_argument(
"--random-mm-base-items-per-request",
type=int,
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
help=(
"Base number of multimodal items per request for random-mm. "
"Actual per-request count is sampled around this base using "
"--random-mm-num-mm-items-range-ratio."
),
)
random_mm_group.add_argument(
"--random-mm-num-mm-items-range-ratio",
type=float,
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
help=(
"Range ratio r in [0, 1] for sampling items per request. "
"We sample uniformly from the closed integer range "
"[floor(n*(1-r)), ceil(n*(1+r))] "
"where n is the base items per request. "
"r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
"to the sum of per-modality limits from "
"--random-mm-limit-mm-per-prompt. "
"An error is raised if the computed min exceeds the max."
),
)
random_mm_group.add_argument(
"--random-mm-limit-mm-per-prompt",
type=json.loads,
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
help=(
"Per-modality hard caps for items attached per request, e.g. "
"'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
"count is clamped to the sum of these limits. When a modality "
"reaches its cap, its buckets are excluded and probabilities are "
"renormalized."
"OBS.: Only image sampling is supported for now."
),
)
def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
# If already a dict (e.g., programmatic call), normalize keys
def normalize(d: dict) -> dict[tuple[int, int, int], float]:
out: dict[tuple[int, int, int], float] = {}
for k, val in d.items():
key = k
if isinstance(key, str):
with suppress(Exception):
key = ast.literal_eval(key)
if not (isinstance(key, tuple) and len(key) == 3
and all(isinstance(x, int) for x in key)):
raise ValueError(
f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
)
out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
return out
if isinstance(v, dict):
return normalize(v)
if isinstance(v, str):
# Python literal (supports tuple keys)
parsed = ast.literal_eval(v)
if not isinstance(parsed, dict):
raise ValueError("Bucket config must parse to a dict.")
return normalize(parsed)
raise ValueError("Unsupported value for --random-mm-bucket-config.")
random_mm_group.add_argument(
"--random-mm-bucket-config",
type=_parse_mm_bucket_config,
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
help=(
"The bucket config is a dictionary mapping a multimodal item"
"sampling configuration to a probability."
"Currently allows for 2 modalities: images and videos. "
"An bucket key is a tuple of (height, width, num_frames)"
"The value is the probability of sampling that specific item. "
"Example: "
"--random-mm-bucket-config "
"{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
"First item: images with resolution 256x256 w.p. 0.5"
"Second item: images with resolution 720x1280 w.p. 0.4 "
"Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1"
"OBS.: If the probabilities do not sum to 1, they are normalized."
"OBS bis.: Only image sampling is supported for now."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
......@@ -647,6 +1275,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "sonnet":
......@@ -660,6 +1289,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
......@@ -671,6 +1301,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "hf":
......@@ -716,10 +1347,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
"openai-chat",
"openai-audio",
]:
# multi-modal benchmark is only available on OpenAI Chat backend.
# multi-modal benchmark is only available on OpenAI Chat
# endpoint-type.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' and "
"'openai-audio' backend.")
"'openai-audio' endpoint-type.")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
......@@ -730,31 +1362,54 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
)
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random":
lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
batchsize=args.random_batch_size,
),
"random-mm":
lambda: RandomMultiModalDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
range_ratio=args.random_range_ratio,
input_len=args.random_input_len,
output_len=args.random_output_len,
base_items_per_request=args.random_mm_base_items_per_request,
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
bucket_config=args.random_mm_bucket_config,
request_id_prefix=args.request_id_prefix,
),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
......@@ -766,10 +1421,18 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len,
request_id_prefix=args.request_id_prefix,
),
}
try:
# Enforce endpoint compatibility for multimodal datasets.
if args.dataset_name == "random-mm" and args.endpoint_type not in [
"openai-chat"]:
raise ValueError(
"Multi-modal content (images) is only supported on "
"'openai-chat' backend."
)
input_requests = dataset_mapping[args.dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
......@@ -839,10 +1502,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
......@@ -864,8 +1528,10 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -909,6 +1575,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
# Calculate average token length for a poem line.
......@@ -934,6 +1601,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines]
samples = []
ind = 0
while len(samples) < num_requests:
extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines)
......@@ -949,7 +1617,9 @@ class SonnetDataset(BenchmarkDataset):
if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
))
ind += 1
return samples
......@@ -1000,6 +1670,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
samples = []
......@@ -1020,6 +1691,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
request_id=request_id_prefix + str(i),
))
return samples
......@@ -1075,11 +1747,13 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
ind = 0
dynamic_output = output_len is None
for item in filtered_data:
......@@ -1111,8 +1785,11 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1141,12 +1818,13 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
......@@ -1168,8 +1846,10 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1198,15 +1878,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \
the code, do not include any explanation."
prompt = (
f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
# apply template
prompt = tokenizer.apply_chat_template(
......@@ -1224,8 +1907,10 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1255,13 +1940,14 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
......@@ -1282,8 +1968,10 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1305,8 +1993,10 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs) -> list:
sampled_requests = []
ind = 0
dynamic_output = output_len is None
for item in self.data:
......@@ -1331,8 +2021,12 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1403,13 +2097,14 @@ class NextEditPredictionDataset(HuggingFaceDataset):
}
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
request_id_prefix: str = "",
**kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
for sample in self.data:
for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
......@@ -1417,10 +2112,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids),
request_id=request_id_prefix + str(i),
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
......@@ -1470,6 +2166,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
......@@ -1477,6 +2174,7 @@ class ASRDataset(HuggingFaceDataset):
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
ind = 0
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
......@@ -1496,7 +2194,9 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
ind += 1
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
......@@ -1504,7 +2204,8 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1541,11 +2242,13 @@ class MLPerfDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
# Force dynamic output length based on reference completion.
dynamic_output = output_len is None
sampled_requests: list[SampleRequest] = []
ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
......@@ -1580,10 +2283,13 @@ class MLPerfDataset(HuggingFaceDataset):
prompt=prompt_formatted,
prompt_len=prompt_len,
expected_output_len=expected_output_len,
request_id=request_id_prefix + str(ind),
)
)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1616,6 +2322,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
suffix_len: int = DEFAULT_SUFFIX_LEN,
num_prefixes: int = DEFAULT_NUM_PREFIXES,
output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
vocab_size = tokenizer.vocab_size
......
......@@ -9,7 +9,7 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
import aiohttp
from tqdm.asyncio import tqdm
......@@ -28,9 +28,10 @@ class RequestFuncInput:
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict | list[dict]] = None
multi_modal_content: Optional[Union[dict, list[dict]]] = None
ignore_eos: bool = False
language: Optional[str] = None
request_id: Optional[str] = None
@dataclass
......@@ -68,7 +69,7 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
payload = {
"model": request_func_input.model_name \
"model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
......@@ -87,6 +88,8 @@ async def async_request_openai_completions(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
......@@ -210,6 +213,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
......@@ -311,6 +316,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file
def to_bytes(y, sr):
......@@ -387,12 +394,61 @@ async def async_request_openai_audio(
return output
async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
):
api_url = request_func_input.api_url
assert api_url.endswith(
"embeddings"
), "OpenAI Embeddings API URL must end with 'embeddings'."
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
payload = {
"model": request_func_input.model,
"input": request_func_input.prompt,
}
output = RequestFuncOutput()
st = time.perf_counter()
try:
async with session.post(
url=api_url,
headers=headers,
json=payload
) as response:
if response.status == 200:
output.latency = time.perf_counter() - st
data = await response.json()
output.success = True
output.generated_text = ""
output.prompt_len = data.get(
"usage", {}).get(
"prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)
if pbar:
pbar.update(1)
return output
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
}
OPENAI_COMPATIBLE_BACKENDS = [
......
......@@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()}
return {
str(k)
if not isinstance(k, (str, int, float, bool, type(None)))
else k: self.clear_inf(v)
for k, v in o.items()
}
elif isinstance(o, list):
return [self.clear_inf(v) for v in o]
elif isinstance(o, float) and math.isinf(o):
......
......@@ -26,6 +26,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Literal, Optional
import aiohttp
......@@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
class TaskType(Enum):
GENERATION = "generation"
EMBEDDING = "embedding"
@dataclass
class BenchmarkMetrics:
completed: int
......@@ -75,6 +81,16 @@ class BenchmarkMetrics:
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput :float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: float
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
......@@ -189,6 +205,51 @@ async def get_request(
yield request, request_rates[request_index]
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float]
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
Args:
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
selected_percentiles: The percentiles to select.
Returns:
The calculated benchmark metrics.
"""
total_input = 0
completed = 0
e2els: list[float] = []
for i in range(len(outputs)):
if outputs[i].success:
e2els.append(outputs[i].latency)
completed += 1
total_input += outputs[i].prompt_len
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = EmbedBenchmarkMetrics(
completed=completed,
total_input=total_input,
request_throughput=completed / dur_s,
total_token_throughput=total_input / dur_s,
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles
],
)
return metrics
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
......@@ -334,7 +395,15 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
if endpoint_type in ASYNC_REQUEST_FUNCS:
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
else:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
......@@ -478,11 +547,12 @@ async def benchmark(
"timestamp": timestamp
})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = (
prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
......@@ -498,7 +568,8 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
request_id=request_id,)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
......@@ -511,6 +582,7 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
......@@ -519,6 +591,13 @@ async def benchmark(
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
else:
metrics = calculate_metrics_for_embeddings(
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
)
actual_output_lens = 0
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
......@@ -527,22 +606,28 @@ async def benchmark(
max_concurrency))
if request_rate != float('inf'):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
request_rate ))
request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format(
"Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
if isinstance(metrics, BenchmarkMetrics):
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
......@@ -560,6 +645,16 @@ async def benchmark(
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
else:
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"request_throughput": metrics.request_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"errors": [output.error for output in outputs],
}
if rps_change_events:
result["rps_change_events"] = rps_change_events
......@@ -596,9 +691,10 @@ async def benchmark(
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric(
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
......@@ -730,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
......@@ -741,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
......@@ -865,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument(
......@@ -958,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
async def main_async(args: argparse.Namespace) -> dict[str, Any]:
print(args)
random.seed(args.seed)
......
......@@ -435,6 +435,14 @@ def validate_args(args):
raise ValueError(
"Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
"Data parallel is not supported in offline benchmark, "
"please use benchmark serving instead"
)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--backend",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
register_replacement)
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def silu_mul_pattern_static(result: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
FUSED_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
}
silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"))
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[
kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
quant_key: QuantKey,
):
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
assert self.quant_key in FUSED_OPS, \
f"unsupported fusion scheme {self.quant_key}"
self.FUSED_OP = FUSED_OPS[self.quant_key]
def empty_quant(self, *args, **kwargs):
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self, symmetric: bool = True):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(quant_key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
at1 = auto_functionalized(SILU_MUL_OP,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale)
return at2[1]
def silu_mul_replacement_static(result: torch.Tensor,
result_silu_mul: torch.Tensor,
def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
scale=scale)
return at[1]
inputs = [
self.empty_quant(5, 4), # result
empty_bf16(5, 4), # result_silu_mul
empty_bf16(5, 4), # input
empty_fp32(1, 1) # scale
]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def __init__(self):
super().__init__(kNvfp4Quant)
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, output_scale: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(SILU_MUL_OP,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(self.QUANT_OP,
output=result,
input=at1[1],
output_scale=output_scale,
input_scale=scale)
return at2[1], at2[2]
def empty_fp8(*args, **kwargs):
fp8 = current_platform.fp8_dtype()
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
def replacement(result: torch.Tensor, output_scale: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
result_block_scale=output_scale,
input=input,
input_global_scale=scale)
return at[1], at[2]
inputs = [
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # result_silu_mul
empty_bf16(5, 64), # input
empty_fp32(1, 1) # scale
]
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmInductorPass):
......@@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass")
inputs = [
empty_fp8(5, 4), # Quant output
empty_bf16(5, 4), # Silu_and_mul output
empty_bf16(5, 4), # Input
empty_fp32(1, 1) # Scale
]
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
def __call__(self, graph: torch.fx.Graph):
self.begin()
......@@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass):
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()
def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern)
......@@ -271,7 +271,7 @@ def split_graph(graph: fx.GraphModule,
outputs.append(
SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
# sort by intetger graph_id, rather than string name
# sort by integer graph_id, rather than string name
outputs.sort(key=lambda x: x.graph_id)
return split_gm, outputs
......@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: list[str], vllm_config: VllmConfig,
graph_pool, vllm_backend: "VllmBackend"):
vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
......@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
......@@ -405,7 +403,6 @@ class VllmBackend:
vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
......@@ -427,19 +424,12 @@ class VllmBackend:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. launguage_model, vision_model, etc.
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global_graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
......@@ -484,7 +474,7 @@ class VllmBackend:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affects the computation graph.
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)
......@@ -586,7 +576,7 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config, self.graph_pool,
self.vllm_config,
self).run(*example_inputs)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
......
......@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
runtime_mode: CUDAGraphMode, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
......@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
......
......@@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
......@@ -18,6 +19,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class AsyncTPPass(VllmInductorPass):
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
......@@ -401,6 +404,18 @@ if flashinfer_comm is not None:
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
try:
_FI_MAX_SIZES.update({
int(k): int(float(v) * MiB)
for k, v in
envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
})
except Exception as e:
raise ValueError(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+ str(e)) from e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
......@@ -465,7 +480,8 @@ if flashinfer_comm is not None:
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
layout_code=flashinfer_comm.QuantizationSFLayout.
SWIZZLED_128x4,
scale_factor=scale_factor,
)
else:
......@@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns()
@enable_fake_mode
def register_patterns(self):
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
......
......@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
......@@ -81,7 +79,9 @@ class CUDAGraphWrapper:
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
......
......@@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool:
return getattr(cls, IGNORE_COMPILE_KEY, False)
@overload
def support_torch_compile(
*,
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(
*,
......@@ -69,6 +77,7 @@ def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
......@@ -118,6 +127,11 @@ def support_torch_compile(
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""
def cls_decorator_helper(cls: _T) -> _T:
......@@ -149,7 +163,8 @@ def support_torch_compile(
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, inferred_dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
enable_if)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
......@@ -162,6 +177,7 @@ def support_torch_compile(
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
......@@ -182,13 +198,14 @@ def _support_torch_compile(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
self.__class__) or not enable_compile
if self.do_not_compile:
return
......@@ -267,8 +284,24 @@ def _support_torch_compile(
code.co_filename)
return inline_call(parent, func, args, kwargs)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
# vllm skip guards anyways, setting this flag to False can improve
# compile time.
dynamo_config_patches = {}
try:
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
dynamo_config_patches[
"enable_cpp_symbolic_shape_guards"] = False
except AttributeError:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger.debug(
"enable_cpp_symbolic_shape_guards config not available")
with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call):
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches):
output = self.compiled_callable(*args, **kwargs)
return output
......
......@@ -9,6 +9,7 @@ import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
......@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug("XPU platform does not support fix functionalization"
"pass currently.")
return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
......@@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass):
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
# elif hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"
# ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
# mutated_args = {1: 'result', 2: 'result_block_scale'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'result_block_scale',
# 'input', 'input_global_scale'))
else:
continue # skip the count
......
......@@ -12,15 +12,18 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
......@@ -31,41 +34,13 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
symmetric: bool = True
def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym:
......@@ -75,6 +50,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
......@@ -187,10 +165,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
......@@ -244,10 +220,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -337,10 +311,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -435,10 +409,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -556,6 +530,7 @@ class FusionPass(VllmInductorPass):
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
@enable_fake_mode
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
class AttentionQuantPattern(ABC):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer_name: str,
num_heads: int,
head_size: int,
quant_dtype: torch.dtype,
symmetric=True,
layer: Attention,
quant_key: QuantKey,
):
self.layer_name = layer_name
self.num_heads = num_heads
self.head_size = head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
......@@ -48,31 +55,64 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def register_if_supported(self, pm_pass: PatternMatcherPass,
layer: Attention):
if layer.impl.fused_output_quant_supported(self.quant_dtype,
self.quant_key.static,
self.quant_key.group_shape):
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
symmetric: bool = True,
):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(layer, quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_attn,
[-1, self.num_heads, self.head_size])
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
......@@ -82,47 +122,116 @@ class AttentionStaticQuantPattern:
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_quant,
[-1, self.num_heads, self.head_size])
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device)
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)
output_scale=scale,
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads * self.head_size), # attn_output
self.empty_quant(5, self.num_heads *
self.head_size), # quant_output
empty_bf16(5, self.num_heads, self.head_size), # attn_output
self.empty_quant(5,
self.num_heads * self.head_size), # quant_output
empty_fp32(1, 1) # scale
]
def wrap_trace_fn(process_fx, trace_fn):
pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
super().__init__(layer, kNvfp4Quant)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(
output_scale, FP8_DTYPE)
at2 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view)
output = RESHAPE_OP(at2[1],
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size //
2), # output_quant
empty_i32(128, round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement(
pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
class AttnFusionPass(VllmInductorPass):
......@@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.static_fwd_ctx = config.compilation_config.static_forward_context
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
for key, layer in self.static_fwd_ctx.items():
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
layer.head_size,
current_platform.fp8_dtype())
pattern.register_if_supported(self.patterns, layer)
if len(self.static_fwd_ctx) == 0:
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered.")
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")
count = self.patterns.apply(graph)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import inspect
import json
......@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import torch
from torch import fx
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.utils import is_torch_equal_or_newer
......@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
with torch._guards.tracing(
None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new
......@@ -43,7 +43,7 @@ cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled():
# used to monitor whether an cudagraph capturing is legal at runtime.
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global cudagraph_capturing_enabled
......
......@@ -8,13 +8,13 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .activation_quant_fusion import ActivationQuantFusionPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
......
......@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
......@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
......
......@@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
......@@ -191,6 +192,16 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
yield a, b
a = b
try:
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
except (OSError, KeyError, TypeError):
# HACK: Python 3.13+ workaround - set missing __firstlineno__
# Workaround can be removed after we upgrade to pydantic==2.12.0
with open(inspect.getfile(cls)) as f:
for i, line in enumerate(f):
if f"class {cls.__name__}" in line and ":" in line:
cls.__firstlineno__ = i + 1
break
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
if not isinstance(cls_node, ast.ClassDef):
......@@ -246,8 +257,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]
class LogprobsMode(enum.Enum):
RAW_LOGITS = "raw_logits"
RAW_LOGPROBS = "raw_logprobs"
PROCESSED_LOGITS = "processed_logits"
PROCESSED_LOGPROBS = "processed_logprobs"
@config
......@@ -351,12 +368,13 @@ class ModelConfig:
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs"
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
......@@ -419,7 +437,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
......@@ -428,6 +446,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
......@@ -470,6 +501,8 @@ class ModelConfig:
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin: Optional[str] = None
"""IOProcessor plugin name to load at model startup"""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
......@@ -845,22 +878,25 @@ class ModelConfig:
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
if self._model_info.supports_multimodal:
if (self.mm_encoder_tp_mode == "data" and
not self._model_info.supports_multimodal_encoder_tp_data):
logger.warning_once(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`.")
self.mm_encoder_tp_mode = "weights"
return MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
skip_mm_profiling=self.skip_mm_profiling,
)
return None
def set_mm_processor_cache_gb(self, value: int) -> None:
mm_config = self.get_multimodal_config()
self.mm_processor_cache_gb = value
mm_config.mm_processor_cache_gb = value
def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
......@@ -1090,9 +1126,20 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = me_quant.QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
"fp8",
"modelopt",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"fbgemm_fp8",
"compressed-tensors",
"experts_int8",
"quark",
"modelopt_fp4",
"bitblas",
"gptq_bitblas",
"inc",
"petit_nvfp4",
]
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods,
......@@ -1115,7 +1162,6 @@ class ModelConfig:
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
"marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
......@@ -1125,6 +1171,7 @@ class ModelConfig:
"moe_wna16",
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
......@@ -1457,7 +1504,8 @@ class ModelConfig:
from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"):
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
......@@ -1657,29 +1705,8 @@ class ModelConfig:
return self.multimodal_config is not None
@property
def processor_return_mm_hashes(self) -> bool:
"""Whether the multi-modal processor should output hashes."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
@property
def enable_mm_processor_cache(self) -> bool:
"""Whether the multi-modal processor cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property
def is_cross_encoder(self) -> bool:
......@@ -1690,10 +1717,6 @@ class ModelConfig:
def is_pp_supported(self) -> bool:
return self._model_info.supports_pp
@property
def is_multimodal_raw_input_supported(self) -> bool:
return self._model_info.supports_multimodal_raw_input
@property
def is_attention_free(self) -> bool:
return self._model_info.is_attention_free
......@@ -1904,7 +1927,8 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
@config
......@@ -2037,6 +2061,16 @@ class SpeculativeConfig:
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
return hf_config
def __post_init__(self):
......@@ -2055,8 +2089,8 @@ class SpeculativeConfig:
if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type \
== "mimo"):
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
......@@ -2154,6 +2188,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
......@@ -2369,7 +2412,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
def __repr__(self) -> str:
method = self.method
......@@ -2401,8 +2444,8 @@ class LoRAConfig:
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
"""(Deprecated) Maximum size of extra vocabulary that can be present in a
LoRA adapter. Will be removed in v0.12.0."""
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
......@@ -2444,6 +2487,12 @@ class LoRAConfig:
return hash_str
def __post_init__(self):
# Deprecation warning for lora_extra_vocab_size
logger.warning(
"`lora_extra_vocab_size` is deprecated and will be removed "
"in v0.12.0. Additional vocabulary support for "
"LoRA adapters is being phased out.")
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
......@@ -2508,7 +2557,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""
The size (in GiB) of the multi-modal processor cache, which is used to
......@@ -2519,6 +2568,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings: bool = False
"""
Enable fully interleaved support for multimodal prompts.
......@@ -2988,7 +3053,8 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
@config
......@@ -3551,7 +3617,7 @@ class VllmConfig:
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_xpu():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment