Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b52805
Commit
d2b52805
authored
Sep 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori
parents
9a521c23
5438967f
Changes
511
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1657 additions
and
433 deletions
+1657
-433
vllm/beam_search.py
vllm/beam_search.py
+1
-1
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+803
-96
vllm/benchmarks/lib/endpoint_request_func.py
vllm/benchmarks/lib/endpoint_request_func.py
+62
-6
vllm/benchmarks/lib/utils.py
vllm/benchmarks/lib/utils.py
+6
-1
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+188
-83
vllm/benchmarks/throughput.py
vllm/benchmarks/throughput.py
+8
-0
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+139
-35
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-15
vllm/compilation/base_static_graph.py
vllm/compilation/base_static_graph.py
+1
-4
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+21
-1
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+4
-4
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+36
-3
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+17
-0
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+22
-47
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+191
-72
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+20
-0
vllm/compilation/monitor.py
vllm/compilation/monitor.py
+1
-1
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+1
-1
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+2
-0
vllm/config/__init__.py
vllm/config/__init__.py
+129
-63
No files found.
Too many changes to show.
To preserve performance only
511 of 511+
files are displayed.
Plain diff
Email patch
vllm/beam_search.py
View file @
d2b52805
...
...
@@ -18,7 +18,7 @@ class BeamSearchSequence:
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens include
s
the prompt.
# The tokens include the prompt.
tokens
:
list
[
int
]
logprobs
:
list
[
dict
[
int
,
Logprob
]]
lora_request
:
Optional
[
LoRARequest
]
=
None
...
...
vllm/benchmarks/datasets.py
View file @
d2b52805
...
...
@@ -11,17 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
import
ast
import
base64
import
io
import
json
import
logging
import
math
import
random
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
collections.abc
import
Iterator
,
Mapping
from
contextlib
import
suppress
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
functools
import
cache
from
io
import
BytesIO
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
cast
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -69,13 +73,14 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
prompt
:
Union
[
str
,
Any
]
prompt
:
Union
[
str
,
list
[
str
]
]
prompt_len
:
int
expected_output_len
:
int
multi_modal_data
:
Optional
[
Union
[
MultiModalDataDict
,
dict
,
list
[
dict
]]
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
request_id
:
Optional
[
str
]
=
None
# -----------------------------------------------------------------------------
...
...
@@ -112,7 +117,9 @@ class BenchmarkDataset(ABC):
def
apply_multimodal_chat_transformation
(
self
,
prompt
:
str
,
mm_content
:
Optional
[
MultiModalDataDict
]
=
None
)
->
list
[
dict
]:
mm_content
:
Optional
[
Union
[
MultiModalDataDict
,
dict
,
list
[
dict
]]
]
=
None
)
->
list
[
dict
]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
...
...
@@ -120,7 +127,15 @@ class BenchmarkDataset(ABC):
"""
content
=
[{
"text"
:
prompt
,
"type"
:
"text"
}]
if
mm_content
is
not
None
:
if
isinstance
(
mm_content
,
list
):
content
.
extend
(
cast
(
list
[
dict
[
str
,
Any
]],
mm_content
))
elif
isinstance
(
mm_content
,
dict
):
content
.
append
(
mm_content
)
else
:
raise
TypeError
(
"Could not process multimodal content of type: "
+
f
"
{
type
(
mm_content
)
}
"
)
return
[{
"role"
:
"user"
,
"content"
:
content
}]
def
load_data
(
self
)
->
None
:
...
...
@@ -183,7 +198,8 @@ class BenchmarkDataset(ABC):
@
abstractmethod
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
)
->
list
[
SampleRequest
]:
num_requests
:
int
,
request_id_prefix
:
str
=
""
)
->
list
[
SampleRequest
]:
"""
Abstract method to generate sample requests from the dataset.
...
...
@@ -194,6 +210,8 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
...
...
@@ -201,8 +219,12 @@ class BenchmarkDataset(ABC):
"""
raise
NotImplementedError
(
"sample must be implemented in subclasses."
)
def
maybe_oversample_requests
(
self
,
requests
:
list
[
SampleRequest
],
num_requests
:
int
)
->
None
:
def
maybe_oversample_requests
(
self
,
requests
:
list
[
SampleRequest
],
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
)
->
None
:
"""
Oversamples the list of requests if its size is less than the desired
number.
...
...
@@ -211,11 +233,17 @@ class BenchmarkDataset(ABC):
requests (List[SampleRequest]): The current list of sampled
requests.
num_requests (int): The target number of requests.
request_id_prefix (str) The prefix of the request ids.
"""
if
len
(
requests
)
<
num_requests
:
random
.
seed
(
self
.
random_seed
)
additional
=
random
.
choices
(
requests
,
k
=
num_requests
-
len
(
requests
))
additional
=
deepcopy
(
random
.
choices
(
requests
,
k
=
num_requests
-
len
(
requests
))
)
for
i
in
range
(
len
(
additional
)):
req
=
additional
[
i
]
req
.
request_id
=
request_id_prefix
+
str
(
len
(
requests
)
+
i
)
requests
.
extend
(
additional
)
logger
.
info
(
"Oversampled requests to reach %d total samples."
,
num_requests
)
...
...
@@ -266,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
"""
Process a single image input and return a multimedia content dictionary.
Supports th
ree
input types:
Supports th
e following
input types:
1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
containing raw image data. - Loads the bytes as a PIL.Image.Image.
...
...
@@ -306,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]:
" or str or dictionary with raw image bytes."
)
def
process_video
(
video
:
Any
)
->
Mapping
[
str
,
Any
]:
"""
Process a single video input and return a multimedia content dictionary.
Supports the following input types:
1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
containing raw video data.
2. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
Raises:
ValueError: If the input is not a supported type.
"""
if
isinstance
(
video
,
dict
)
and
'bytes'
in
video
:
video_bytes
=
video
[
'bytes'
]
video_base64
=
base64
.
b64encode
(
video_bytes
).
decode
(
"utf-8"
)
return
{
"type"
:
"video_url"
,
"video_url"
:
{
"url"
:
f
"data:video/mp4;base64,
{
video_base64
}
"
},
}
if
isinstance
(
video
,
str
):
video_url
=
(
video
if
video
.
startswith
(
(
"http://"
,
"file://"
))
else
f
"file://
{
video
}
"
)
return
{
"type"
:
"video_url"
,
"video_url"
:
{
"url"
:
video_url
}}
raise
ValueError
(
f
"Invalid video input
{
video
}
. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."
# noqa: E501
)
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class
RandomDataset
(
BenchmarkDataset
):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Strategy:
- Sample input/output token lengths per request from integer-uniform ranges
around configured means (controlled by range_ratio).
- Prepend a fixed random prefix of length prefix_len.
- Generate the remaining tokens as a reproducible sequence:
(offset + index + arange(input_len)) % vocab_size.
- Decode then re-encode/truncate to ensure prompt token counts match.
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
"""
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN
=
0
DEFAULT_RANGE_RATIO
=
0.0
DEFAULT_INPUT_LEN
=
1024
DEFAULT_OUTPUT_LEN
=
128
def
__init__
(
self
,
**
kwargs
,
)
->
None
:
def
__init__
(
self
,
**
kwargs
)
->
None
:
super
().
__init__
(
**
kwargs
)
random
.
seed
(
self
.
random_seed
)
np
.
random
.
seed
(
self
.
random_seed
)
# Use numpy's default_rng for deterministic sampling
# Do not use random.seed() or np.random.seed() elsewhere in this class.
# This ensures that the RNG is isolated from global RNG state.
self
.
_rng
=
np
.
random
.
default_rng
(
self
.
random_seed
)
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
prefix_len
:
int
=
DEFAULT_PREFIX_LEN
,
range_ratio
:
float
=
DEFAULT_RANGE_RATIO
,
input_len
:
int
=
DEFAULT_INPUT_LEN
,
output_len
:
int
=
DEFAULT_OUTPUT_LEN
,
batchsize
:
int
=
1
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
# Enforce range_ratio < 1
assert
range_ratio
<
1.0
,
(
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
input_lens
,
output_lens
,
offsets
=
self
.
get_sampling_params
(
num_requests
,
range_ratio
,
input_len
,
output_len
,
tokenizer
)
# Generate prefix once
prefix_token_ids
=
self
.
get_prefix
(
tokenizer
,
prefix_len
)
vocab_size
=
tokenizer
.
vocab_size
num_special_tokens
=
tokenizer
.
num_special_tokens_to_add
()
real_input_len
=
input_len
-
num_special_tokens
prefix_token_ids
=
(
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
prefix_len
).
tolist
()
if
prefix_len
>
0
else
[])
requests
=
[]
for
i
in
range
(
num_requests
):
prompt
,
total_input_len
=
self
.
generate_token_sequence
(
tokenizer
=
tokenizer
,
prefix_token_ids
=
prefix_token_ids
,
prefix_len
=
prefix_len
,
vocab_size
=
vocab_size
,
input_len
=
int
(
input_lens
[
i
]),
offset
=
int
(
offsets
[
i
]),
index
=
i
,
)
requests
.
append
(
SampleRequest
(
prompt
=
prompt
,
prompt_len
=
total_input_len
,
expected_output_len
=
int
(
output_lens
[
i
]),
request_id
=
request_id_prefix
+
str
(
i
),
)
)
# only used for embeddings benchmark.
if
batchsize
>
1
:
batch_requests
=
[]
# Create batched requests
for
i
in
range
(
0
,
num_requests
,
batchsize
):
batch
=
requests
[
i
:
i
+
batchsize
]
batch_requests
.
append
(
SampleRequest
(
prompt
=
[
req
.
prompt
for
req
in
batch
],
prompt_len
=
sum
(
req
.
prompt_len
for
req
in
batch
),
expected_output_len
=
0
,
request_id
=
request_id_prefix
+
str
(
i
//
batchsize
),
)
)
requests
=
batch_requests
return
requests
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low
=
int
(
real_input_len
*
(
1
-
range_ratio
))
input_high
=
int
(
real_input_len
*
(
1
+
range_ratio
))
output_low
=
int
(
output_len
*
(
1
-
range_ratio
))
output_high
=
int
(
output_len
*
(
1
+
range_ratio
))
def
get_prefix
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
prefix_len
:
int
)
->
list
[
int
]:
"""
Get the prefix for the dataset.
"""
return
(
self
.
_rng
.
integers
(
0
,
tokenizer
.
vocab_size
,
size
=
prefix_len
).
tolist
()
if
prefix_len
>
0
else
[]
)
def
get_sampling_params
(
self
,
num_requests
:
int
,
range_ratio
:
float
,
input_len
:
int
,
output_len
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if
not
(
0.0
<=
range_ratio
<
1.0
):
raise
ValueError
(
"range_ratio must be in [0, 1)."
)
num_special_tokens
=
int
(
tokenizer
.
num_special_tokens_to_add
())
real_input_len
=
max
(
0
,
int
(
input_len
)
-
num_special_tokens
)
# Bounds use floor for low and ceil for high
input_low
=
math
.
floor
(
real_input_len
*
(
1
-
range_ratio
))
input_high
=
math
.
ceil
(
real_input_len
*
(
1
+
range_ratio
))
output_low
=
math
.
floor
(
output_len
*
(
1
-
range_ratio
))
output_high
=
math
.
ceil
(
output_len
*
(
1
+
range_ratio
))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low
=
max
(
output_low
,
1
)
if
input_low
>
input_high
:
raise
ValueError
(
"Invalid input sampling interval: "
f
"low=
{
input_low
}
> high=
{
input_high
}
"
)
if
output_low
>
output_high
:
raise
ValueError
(
"Invalid output sampling interval: "
f
"low=
{
output_low
}
> high=
{
output_high
}
"
)
# Add logging for debugging
logger
.
info
(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]"
,
input_low
,
input_high
,
output_low
,
output_high
)
input_low
,
input_high
,
output_low
,
output_high
,
)
input_lens
=
np
.
random
.
randint
(
input_low
,
input_high
+
1
,
input_lens
=
self
.
_rng
.
integers
(
input_low
,
input_high
+
1
,
size
=
num_requests
)
output_lens
=
np
.
random
.
randint
(
output_low
,
output_high
+
1
,
output_lens
=
self
.
_rng
.
integers
(
output_low
,
output_high
+
1
,
size
=
num_requests
)
offsets
=
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
num_requests
)
offsets
=
self
.
_rng
.
integers
(
0
,
tokenizer
.
vocab_size
,
size
=
num_requests
)
return
input_lens
,
output_lens
,
offsets
requests
=
[]
for
i
in
range
(
num_requests
):
inner_seq
=
((
offsets
[
i
]
+
i
+
np
.
arange
(
input_lens
[
i
]))
%
vocab_size
).
tolist
()
def
generate_token_sequence
(
self
,
*
,
tokenizer
:
PreTrainedTokenizerBase
,
prefix_token_ids
:
list
[
int
],
prefix_len
:
int
,
vocab_size
:
int
,
input_len
:
int
,
offset
:
int
,
index
:
int
,
)
->
tuple
[
str
,
int
]:
"""
Returns (prompt, total_input_len).
NOTE: After decoding the prompt we have to encode and decode it again.
This is done because in some cases N consecutive tokens
give a string tokenized into != N number of tokens.
For example for GPT2Tokenizer:
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
To avoid uncontrolled change of the prompt length,
the encoded sequence is truncated before being decode again.
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq
=
((
offset
+
index
+
np
.
arange
(
input_len
))
%
vocab_size
).
tolist
()
token_sequence
=
prefix_token_ids
+
inner_seq
# Decode, then re-encode and truncate to preserve token count invariants
prompt
=
tokenizer
.
decode
(
token_sequence
)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
total_input_len
=
prefix_len
+
int
(
input_lens
[
i
])
total_input_len
=
prefix_len
+
int
(
input_len
)
re_encoded_sequence
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)[:
total_input_len
]
prompt
=
tokenizer
.
decode
(
re_encoded_sequence
)
total_input_len
=
len
(
re_encoded_sequence
)
requests
.
append
(
SampleRequest
(
return
prompt
,
total_input_len
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
class
RandomMultiModalDataset
(
RandomDataset
):
"""
Synthetic multimodal dataset (text + images) that extends RandomDataset.
Status:
- Images: supported via synthetic RGB data.
- Video: not yet supported (TODO: implement video generation method).
- Audio: not yet supported.
Sampling overview:
1) Number of items per request is sampled uniformly from the integer range
[floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
remaining probabilities are renormalized.
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
OBS.: Only image sampling is supported for now.
"""
IS_MULTIMODAL
=
True
# NOTE: video sampling is WIP. Setting it to 0.
DEFAULT_LIMIT_MM_PER_PROMPT
=
{
"image"
:
255
,
"video"
:
0
}
DEFAULT_BASE_ITEMS_PER_REQUEST
=
1
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO
=
0.0
DEFAULT_MM_ITEM_BUCKET_CONFIG
=
{
(
256
,
256
,
1
):
0.5
,
(
720
,
1280
,
1
):
0.5
,
(
720
,
1280
,
16
):
0.0
,
}
DEFAULT_ENABLE_MULTIMODAL_CHAT
=
False
def
__init__
(
self
,
**
kwargs
)
->
None
:
super
().
__init__
(
**
kwargs
)
def
generate_synthetic_image
(
self
,
width
:
int
,
height
:
int
)
->
Image
.
Image
:
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress.
"""
random_pixels
=
self
.
_rng
.
integers
(
0
,
256
,
(
height
,
width
,
3
),
dtype
=
np
.
uint8
,
)
return
Image
.
fromarray
(
random_pixels
)
def
generate_synthetic_video
(
self
,
width
:
int
,
height
:
int
,
num_frames
:
int
)
->
Any
:
"""Generate synthetic video with random values.
TODO: Finish this method.
"""
raise
NotImplementedError
(
"Video sampling is WIP."
)
def
map_config_to_modality
(
self
,
config
:
tuple
[
int
,
int
,
int
])
->
str
:
"""Map the configuration to the modality."""
if
config
[
-
1
]
==
1
:
return
"image"
elif
config
[
-
1
]
>
1
:
return
"video"
else
:
raise
ValueError
(
f
"Invalid multimodal item configuration:
{
config
}
"
)
def
normalize_bucket_config
(
self
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
])
->
dict
[
tuple
[
int
,
int
,
int
],
float
]:
"""
Remove zero probability entries
and normalize the bucket config to sum to 1.
"""
# Raise error if value is negative
if
any
(
v
<
0
for
v
in
bucket_config
.
values
()):
raise
ValueError
(
"Bucket config values must be non-negative."
)
# Remove zero probability entries
bucket_config
=
{
k
:
v
for
k
,
v
in
bucket_config
.
items
()
if
v
>
0
}
# if bucket config is empty, raise error
if
not
bucket_config
:
raise
ValueError
(
"Got invalid bucket config. "
"Bucket config values must be non-zero."
)
# Normalize the remaining bucket config to sum to 1
total
=
sum
(
bucket_config
.
values
())
return
{
k
:
v
/
total
for
k
,
v
in
bucket_config
.
items
()}
def
generate_mm_item
(
self
,
mm_item_config
:
tuple
[
int
,
int
,
int
],
)
->
Mapping
[
str
,
Any
]:
"""
Create synthetic images and videos and
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
"""
if
self
.
map_config_to_modality
(
mm_item_config
)
==
"image"
:
return
process_image
(
self
.
generate_synthetic_image
(
mm_item_config
[
1
],
mm_item_config
[
0
]))
elif
self
.
map_config_to_modality
(
mm_item_config
)
==
"video"
:
return
process_video
(
self
.
generate_synthetic_video
(
mm_item_config
[
1
],
mm_item_config
[
0
],
mm_item_config
[
2
]))
else
:
raise
ValueError
(
f
"Invalid multimodal item configuration: "
f
"
{
mm_item_config
}
"
)
def
get_mm_item_sampling_params
(
self
,
base_items_per_request
:
int
,
num_mm_items_range_ratio
:
float
,
limit_mm_per_prompt
:
dict
[
str
,
int
],
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
],
)
->
tuple
[
int
,
int
,
dict
[
str
,
int
],
dict
[
tuple
[
int
,
int
,
int
],
float
]]:
"""
Get the sampling parameters for the multimodal items.
"""
# Enforce num_mm_items_range_ratio <= 1
if
not
(
0.0
<=
num_mm_items_range_ratio
<=
1.0
):
raise
ValueError
(
"num_mm_items_range_ratio must be in [0, 1]."
)
# Ensure modalities to sample are in limit_mm_per_prompt
for
k
,
v
in
bucket_config
.
items
():
# get modality from bucket config
modality
=
self
.
map_config_to_modality
(
k
)
if
modality
not
in
limit_mm_per_prompt
:
raise
ValueError
(
f
"Modality
{
modality
}
is not in "
f
"limit_mm_per_prompt: "
f
"
{
limit_mm_per_prompt
.
keys
()
}
"
)
# Remove zero probability entries
# and normalize bucket config to sum to 1
bucket_config
=
self
.
normalize_bucket_config
(
bucket_config
)
logger
.
info
(
"Normalized bucket config: %s"
,
bucket_config
,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities
=
{
self
.
map_config_to_modality
(
cfg
)
for
cfg
in
bucket_config
}
limit_mm_per_prompt
=
{
k
:
v
for
k
,
v
in
limit_mm_per_prompt
.
items
()
if
k
in
allowed_modalities
}
if
not
limit_mm_per_prompt
:
raise
ValueError
(
"No valid limits for modalities present in "
"bucket_config."
)
logger
.
info
(
"Updated mm-limit-per-prompt: %s"
,
limit_mm_per_prompt
,
)
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items
=
min
(
sum
(
limit_mm_per_prompt
.
values
()),
math
.
ceil
(
base_items_per_request
*
(
1
+
num_mm_items_range_ratio
))
)
# Ensure min num mm items is at least 0
min_num_mm_items
=
max
(
0
,
math
.
floor
(
base_items_per_request
*
(
1
-
num_mm_items_range_ratio
))
)
# Raise error if min num mm items is greater than max num mm items
if
min_num_mm_items
>
max_num_mm_items
:
raise
ValueError
(
f
"Min num mm items is greater than max mm items: "
f
"
{
min_num_mm_items
}
>
{
max_num_mm_items
}
"
)
logger
.
info
(
"Sampling number of multimodal items from [%s, %s]"
,
min_num_mm_items
,
max_num_mm_items
,
)
return
(
min_num_mm_items
,
max_num_mm_items
,
limit_mm_per_prompt
,
bucket_config
,
)
def
get_mm_item_iterator
(
self
,
min_num_mm_items
:
int
,
max_num_mm_items
:
int
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
],
limit_mm_per_prompt
:
dict
[
str
,
int
],
)
->
Iterator
[
tuple
[
int
,
int
,
int
]]:
"""
Iterator over the multimodal items for each request
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
Note:
- This function operates on a per-request shallow copy of
`bucket_config` (tuple->float). The original dict passed to
`sample` is not mutated. If this ever changes, a test
is implemented and will fail.
"""
# Get the number of multimodal items to sample
request_num_mm_items
=
int
(
self
.
_rng
.
integers
(
min_num_mm_items
,
max_num_mm_items
+
1
)
)
# If request_num_mm_items is 0, yield an empty iterator
if
request_num_mm_items
==
0
:
return
# Initialize modality counters
modality_counter
=
{
self
.
map_config_to_modality
(
k
):
0
for
k
in
bucket_config
}
# Copy the bucket config to avoid modifying the original
bucket_config_copy
=
bucket_config
.
copy
()
# Loop over the number of multimodal items to sample
while
sum
(
modality_counter
.
values
())
<
request_num_mm_items
:
# Sample a multimodal item config
mm_item_config
=
self
.
_rng
.
choice
(
list
(
bucket_config_copy
.
keys
()),
p
=
list
(
bucket_config_copy
.
values
()))
modality
=
self
.
map_config_to_modality
(
mm_item_config
)
# Check that modality count is less than limit per prompt
if
modality_counter
[
modality
]
<
limit_mm_per_prompt
[
modality
]:
modality_counter
[
modality
]
+=
1
yield
(
mm_item_config
)
else
:
# If the counter is greater than the limit per prompt
# set all multimodal items of this modality to 0
for
k
,
v
in
bucket_config_copy
.
items
():
if
self
.
map_config_to_modality
(
k
)
==
modality
:
bucket_config_copy
[
k
]
=
0
# If all configs are 0, break the loop
# This should not happen as request_num_mm_items is at most
# the sum of limit_mm_per_prompt for all modalities
if
all
(
v
==
0
for
v
in
bucket_config_copy
.
values
()):
logger
.
warning
(
"Exhausted all multimodal items "
"of modality %s"
,
modality
)
break
# Renormalize the bucket config
bucket_config_copy
=
self
.
normalize_bucket_config
(
bucket_config_copy
)
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
prefix_len
:
int
=
RandomDataset
.
DEFAULT_PREFIX_LEN
,
range_ratio
:
float
=
RandomDataset
.
DEFAULT_RANGE_RATIO
,
input_len
:
int
=
RandomDataset
.
DEFAULT_INPUT_LEN
,
output_len
:
int
=
RandomDataset
.
DEFAULT_OUTPUT_LEN
,
limit_mm_per_prompt
:
dict
[
str
,
int
]
=
DEFAULT_LIMIT_MM_PER_PROMPT
,
base_items_per_request
:
int
=
DEFAULT_BASE_ITEMS_PER_REQUEST
,
num_mm_items_range_ratio
:
float
=
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
]
=
DEFAULT_MM_ITEM_BUCKET_CONFIG
,
enable_multimodal_chat
:
bool
=
DEFAULT_ENABLE_MULTIMODAL_CHAT
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if
any
(
self
.
map_config_to_modality
(
cfg
)
==
"video"
and
p
>
0
for
cfg
,
p
in
bucket_config
.
items
()):
raise
NotImplementedError
(
"Video sampling not implemented; "
"set its probability to 0."
)
# Get the sampling parameters for the dataset
input_lens
,
output_lens
,
offsets
=
self
.
get_sampling_params
(
num_requests
,
range_ratio
,
input_len
,
output_len
,
tokenizer
)
(
min_num_mm_items
,
max_num_mm_items
,
limit_mm_per_prompt
,
bucket_config
,
)
=
self
.
get_mm_item_sampling_params
(
base_items_per_request
,
num_mm_items_range_ratio
,
limit_mm_per_prompt
,
bucket_config
,
)
# Generate prefix once
prefix_token_ids
=
self
.
get_prefix
(
tokenizer
,
prefix_len
)
vocab_size
=
tokenizer
.
vocab_size
# Add synthetic multimodal items to each request
mm_requests
=
[]
for
i
in
range
(
num_requests
):
prompt
,
total_input_len
=
self
.
generate_token_sequence
(
tokenizer
=
tokenizer
,
prefix_token_ids
=
prefix_token_ids
,
prefix_len
=
prefix_len
,
vocab_size
=
vocab_size
,
input_len
=
int
(
input_lens
[
i
]),
offset
=
int
(
offsets
[
i
]),
index
=
i
,
)
# Get multimodal item iterator for a given request
mm_item_iterator
=
self
.
get_mm_item_iterator
(
min_num_mm_items
,
max_num_mm_items
,
bucket_config
,
limit_mm_per_prompt
,
)
mm_content
=
cast
(
list
[
dict
[
str
,
Any
]],
[
self
.
generate_mm_item
(
mm_item_config
)
for
mm_item_config
in
mm_item_iterator
])
if
enable_multimodal_chat
:
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt
:
Any
=
prompt
mm_chat_prompt
=
self
.
apply_multimodal_chat_transformation
(
prompt
,
mm_content
)
sample_request
=
SampleRequest
(
prompt
=
mm_chat_prompt
,
prompt_len
=
total_input_len
,
expected_output_len
=
int
(
output_lens
[
i
]),
multi_modal_data
=
None
,
request_id
=
request_id_prefix
+
str
(
i
),
)
else
:
sample_request
=
SampleRequest
(
prompt
=
prompt
,
prompt_len
=
total_input_len
,
expected_output_len
=
int
(
output_lens
[
i
]),
))
return
requests
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
i
),
)
mm_requests
.
append
(
sample_request
)
return
mm_requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
...
...
@@ -432,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras
:
Optional
[
int
]
=
None
,
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
samples
:
list
=
[]
ind
=
0
for
entry
in
self
.
data
:
if
len
(
samples
)
>=
num_requests
:
break
...
...
@@ -455,9 +983,10 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check
=
output_len
is
not
None
):
continue
# TODO: Also support ShareGPT4Video.
if
image_path
:
=
entry
.
get
(
"image"
):
mm_content
=
process_image
(
image_path
)
elif
video_path
:
=
entry
.
get
(
"video"
):
mm_content
=
process_video
(
video_path
)
else
:
mm_content
=
None
if
enable_multimodal_chat
:
...
...
@@ -470,8 +999,10 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len
=
new_output_len
,
lora_request
=
lora_request
,
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
ind
),
))
self
.
maybe_oversample_requests
(
samples
,
num_requests
)
ind
+=
1
self
.
maybe_oversample_requests
(
samples
,
num_requests
,
request_id_prefix
)
return
samples
...
...
@@ -488,8 +1019,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
type
=
str
,
default
=
"random"
,
choices
=
[
"sharegpt"
,
"burstgpt"
,
"sonnet"
,
"random"
,
"
hf"
,
"custom
"
,
"prefix_repetition"
"sharegpt"
,
"burstgpt"
,
"sonnet"
,
"random"
,
"
random-mm"
,
"hf
"
,
"custom"
,
"prefix_repetition"
],
help
=
"Name of the dataset to benchmark on."
,
)
...
...
@@ -589,6 +1120,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."
),
)
random_group
.
add_argument
(
"--random-batch-size"
,
type
=
int
,
default
=
1
,
help
=
(
"Batch size for random sampling. "
"Only used for embeddings benchmark."
),
)
# random multimodal dataset options
random_mm_group
=
parser
.
add_argument_group
(
"random multimodal dataset options extended from random dataset"
)
random_mm_group
.
add_argument
(
"--random-mm-base-items-per-request"
,
type
=
int
,
default
=
RandomMultiModalDataset
.
DEFAULT_BASE_ITEMS_PER_REQUEST
,
help
=
(
"Base number of multimodal items per request for random-mm. "
"Actual per-request count is sampled around this base using "
"--random-mm-num-mm-items-range-ratio."
),
)
random_mm_group
.
add_argument
(
"--random-mm-num-mm-items-range-ratio"
,
type
=
float
,
default
=
RandomMultiModalDataset
.
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO
,
help
=
(
"Range ratio r in [0, 1] for sampling items per request. "
"We sample uniformly from the closed integer range "
"[floor(n*(1-r)), ceil(n*(1+r))] "
"where n is the base items per request. "
"r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
"to the sum of per-modality limits from "
"--random-mm-limit-mm-per-prompt. "
"An error is raised if the computed min exceeds the max."
),
)
random_mm_group
.
add_argument
(
"--random-mm-limit-mm-per-prompt"
,
type
=
json
.
loads
,
default
=
RandomMultiModalDataset
.
DEFAULT_LIMIT_MM_PER_PROMPT
,
help
=
(
"Per-modality hard caps for items attached per request, e.g. "
"'{
\"
image
\"
: 3,
\"
video
\"
: 0}'. The sampled per-request item "
"count is clamped to the sum of these limits. When a modality "
"reaches its cap, its buckets are excluded and probabilities are "
"renormalized."
"OBS.: Only image sampling is supported for now."
),
)
def
_parse_mm_bucket_config
(
v
:
object
)
->
dict
[
tuple
[
int
,
int
,
int
],
float
]:
# If already a dict (e.g., programmatic call), normalize keys
def
normalize
(
d
:
dict
)
->
dict
[
tuple
[
int
,
int
,
int
],
float
]:
out
:
dict
[
tuple
[
int
,
int
,
int
],
float
]
=
{}
for
k
,
val
in
d
.
items
():
key
=
k
if
isinstance
(
key
,
str
):
with
suppress
(
Exception
):
key
=
ast
.
literal_eval
(
key
)
if
not
(
isinstance
(
key
,
tuple
)
and
len
(
key
)
==
3
and
all
(
isinstance
(
x
,
int
)
for
x
in
key
)):
raise
ValueError
(
f
"Invalid bucket key
{
k
!
r
}
. Expected tuple (H, W, T)."
)
out
[(
int
(
key
[
0
]),
int
(
key
[
1
]),
int
(
key
[
2
]))]
=
float
(
val
)
return
out
if
isinstance
(
v
,
dict
):
return
normalize
(
v
)
if
isinstance
(
v
,
str
):
# Python literal (supports tuple keys)
parsed
=
ast
.
literal_eval
(
v
)
if
not
isinstance
(
parsed
,
dict
):
raise
ValueError
(
"Bucket config must parse to a dict."
)
return
normalize
(
parsed
)
raise
ValueError
(
"Unsupported value for --random-mm-bucket-config."
)
random_mm_group
.
add_argument
(
"--random-mm-bucket-config"
,
type
=
_parse_mm_bucket_config
,
default
=
RandomMultiModalDataset
.
DEFAULT_MM_ITEM_BUCKET_CONFIG
,
help
=
(
"The bucket config is a dictionary mapping a multimodal item"
"sampling configuration to a probability."
"Currently allows for 2 modalities: images and videos. "
"An bucket key is a tuple of (height, width, num_frames)"
"The value is the probability of sampling that specific item. "
"Example: "
"--random-mm-bucket-config "
"{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
"First item: images with resolution 256x256 w.p. 0.5"
"Second item: images with resolution 720x1280 w.p. 0.4 "
"Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1"
"OBS.: If the probabilities do not sum to 1, they are normalized."
"OBS bis.: Only image sampling is supported for now."
),
)
hf_group
=
parser
.
add_argument_group
(
"hf dataset options"
)
hf_group
.
add_argument
(
"--hf-subset"
,
...
...
@@ -647,6 +1275,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer
=
tokenizer
,
output_len
=
args
.
custom_output_len
,
skip_chat_template
=
args
.
custom_skip_chat_template
,
request_id_prefix
=
args
.
request_id_prefix
,
)
elif
args
.
dataset_name
==
"sonnet"
:
...
...
@@ -660,6 +1289,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len
=
args
.
sonnet_prefix_len
,
tokenizer
=
tokenizer
,
return_prompt_formatted
=
False
,
request_id_prefix
=
args
.
request_id_prefix
,
)
else
:
assert
tokenizer
.
chat_template
or
tokenizer
.
default_chat_template
,
(
...
...
@@ -671,6 +1301,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len
=
args
.
sonnet_prefix_len
,
tokenizer
=
tokenizer
,
return_prompt_formatted
=
True
,
request_id_prefix
=
args
.
request_id_prefix
,
)
elif
args
.
dataset_name
==
"hf"
:
...
...
@@ -716,10 +1347,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
"openai-chat"
,
"openai-audio"
,
]:
# multi-modal benchmark is only available on OpenAI Chat backend.
# multi-modal benchmark is only available on OpenAI Chat
# endpoint-type.
raise
ValueError
(
"Multi-modal content is only supported on 'openai-chat' and "
"'openai-audio'
backend
."
)
"'openai-audio'
endpoint-type
."
)
input_requests
=
dataset_class
(
dataset_path
=
args
.
dataset_path
,
dataset_subset
=
args
.
hf_subset
,
...
...
@@ -730,31 +1362,54 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
output_len
=
args
.
hf_output_len
,
request_id_prefix
=
args
.
request_id_prefix
,
)
else
:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping
=
{
"sharegpt"
:
lambda
:
ShareGPTDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
"sharegpt"
:
lambda
:
ShareGPTDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
tokenizer
=
tokenizer
,
num_requests
=
args
.
num_prompts
,
output_len
=
args
.
sharegpt_output_len
,
request_id_prefix
=
args
.
request_id_prefix
,
),
"burstgpt"
:
lambda
:
BurstGPTDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
tokenizer
=
tokenizer
,
num_requests
=
args
.
num_prompts
),
"random"
:
lambda
:
RandomDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
"burstgpt"
:
lambda
:
BurstGPTDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
tokenizer
=
tokenizer
,
num_requests
=
args
.
num_prompts
,
request_id_prefix
=
args
.
request_id_prefix
,
),
"random"
:
lambda
:
RandomDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
tokenizer
=
tokenizer
,
num_requests
=
args
.
num_prompts
,
prefix_len
=
args
.
random_prefix_len
,
input_len
=
args
.
random_input_len
,
output_len
=
args
.
random_output_len
,
range_ratio
=
args
.
random_range_ratio
,
request_id_prefix
=
args
.
request_id_prefix
,
batchsize
=
args
.
random_batch_size
,
),
"random-mm"
:
lambda
:
RandomMultiModalDataset
(
random_seed
=
args
.
seed
,
dataset_path
=
args
.
dataset_path
).
sample
(
tokenizer
=
tokenizer
,
num_requests
=
args
.
num_prompts
,
prefix_len
=
args
.
random_prefix_len
,
range_ratio
=
args
.
random_range_ratio
,
input_len
=
args
.
random_input_len
,
output_len
=
args
.
random_output_len
,
base_items_per_request
=
args
.
random_mm_base_items_per_request
,
limit_mm_per_prompt
=
args
.
random_mm_limit_mm_per_prompt
,
num_mm_items_range_ratio
=
args
.
random_mm_num_mm_items_range_ratio
,
bucket_config
=
args
.
random_mm_bucket_config
,
request_id_prefix
=
args
.
request_id_prefix
,
),
"prefix_repetition"
:
lambda
:
PrefixRepetitionRandomDataset
(
...
...
@@ -766,10 +1421,18 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
suffix_len
=
args
.
prefix_repetition_suffix_len
,
num_prefixes
=
args
.
prefix_repetition_num_prefixes
,
output_len
=
args
.
prefix_repetition_output_len
,
request_id_prefix
=
args
.
request_id_prefix
,
),
}
try
:
# Enforce endpoint compatibility for multimodal datasets.
if
args
.
dataset_name
==
"random-mm"
and
args
.
endpoint_type
not
in
[
"openai-chat"
]:
raise
ValueError
(
"Multi-modal content (images) is only supported on "
"'openai-chat' backend."
)
input_requests
=
dataset_mapping
[
args
.
dataset_name
]()
except
KeyError
as
err
:
raise
ValueError
(
f
"Unknown dataset:
{
args
.
dataset_name
}
"
)
from
err
...
...
@@ -839,10 +1502,11 @@ class CustomDataset(BenchmarkDataset):
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
skip_chat_template
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
sampled_requests
=
[]
for
item
in
self
.
data
:
for
i
,
item
in
enumerate
(
self
.
data
)
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
prompt
=
item
[
"prompt"
]
...
...
@@ -864,8 +1528,10 @@ class CustomDataset(BenchmarkDataset):
prompt
=
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -909,6 +1575,7 @@ class SonnetDataset(BenchmarkDataset):
input_len
:
int
=
DEFAULT_INPUT_LEN
,
output_len
:
int
=
DEFAULT_OUTPUT_LEN
,
return_prompt_formatted
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
# Calculate average token length for a poem line.
...
...
@@ -934,6 +1601,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines
=
self
.
data
[:
num_prefix_lines
]
samples
=
[]
ind
=
0
while
len
(
samples
)
<
num_requests
:
extra_lines
=
random
.
choices
(
self
.
data
,
k
=
num_input_lines
-
num_prefix_lines
)
...
...
@@ -949,7 +1617,9 @@ class SonnetDataset(BenchmarkDataset):
if
return_prompt_formatted
else
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
ind
),
))
ind
+=
1
return
samples
...
...
@@ -1000,6 +1670,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests
:
int
,
max_loras
:
Optional
[
int
]
=
None
,
lora_path
:
Optional
[
str
]
=
None
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
samples
=
[]
...
...
@@ -1020,6 +1691,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len
=
input_len
,
expected_output_len
=
output_len
,
lora_request
=
lora_req
,
request_id
=
request_id_prefix
+
str
(
i
),
))
return
samples
...
...
@@ -1075,11 +1747,13 @@ class ConversationDataset(HuggingFaceDataset):
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
)
->
list
:
# Filter examples with at least 2 conversations
filtered_data
=
self
.
data
.
filter
(
lambda
x
:
len
(
x
[
"conversations"
])
>=
2
)
sampled_requests
=
[]
ind
=
0
dynamic_output
=
output_len
is
None
for
item
in
filtered_data
:
...
...
@@ -1111,8 +1785,11 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
ind
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1141,12 +1818,13 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
output_len
=
(
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
)
sampled_requests
=
[]
for
item
in
self
.
data
:
for
i
,
item
in
enumerate
(
self
.
data
)
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
parser_fn
=
self
.
SUPPORTED_DATASET_PATHS
.
get
(
self
.
dataset_path
)
...
...
@@ -1168,8 +1846,10 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
i
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1198,15 +1878,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
)
->
list
:
output_len
=
(
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
)
sampled_requests
=
[]
for
item
in
self
.
data
:
for
i
,
item
in
enumerate
(
self
.
data
)
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
prompt
=
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output
\
the code, do not include any explanation."
prompt
=
(
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output "
"the code, do not include any explanation."
)
# apply template
prompt
=
tokenizer
.
apply_chat_template
(
...
...
@@ -1224,8 +1907,10 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt
=
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1255,13 +1940,14 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
enable_multimodal_chat
:
bool
=
False
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
output_len
=
(
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
)
sampled_requests
=
[]
for
item
in
self
.
data
:
for
i
,
item
in
enumerate
(
self
.
data
)
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
prompt
=
item
[
"turns"
][
0
]
...
...
@@ -1282,8 +1968,10 @@ class MTBenchDataset(HuggingFaceDataset):
prompt
=
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1305,8 +1993,10 @@ class AIMODataset(HuggingFaceDataset):
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
request_id_prefix
:
str
=
""
,
**
kwargs
)
->
list
:
sampled_requests
=
[]
ind
=
0
dynamic_output
=
output_len
is
None
for
item
in
self
.
data
:
...
...
@@ -1331,8 +2021,12 @@ class AIMODataset(HuggingFaceDataset):
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
None
,
request_id
=
request_id_prefix
+
str
(
ind
),
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1403,13 +2097,14 @@ class NextEditPredictionDataset(HuggingFaceDataset):
}
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
**
kwargs
):
formatting_prompt_func
=
self
.
MAPPING_PROMPT_FUNCS
.
get
(
self
.
dataset_path
)
if
formatting_prompt_func
is
None
:
raise
ValueError
(
f
"Unsupported dataset path:
{
self
.
dataset_path
}
"
)
samples
=
[]
for
sample
in
self
.
data
:
for
i
,
sample
in
enumerate
(
self
.
data
)
:
sample
=
formatting_prompt_func
(
sample
)
samples
.
append
(
SampleRequest
(
...
...
@@ -1417,10 +2112,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt_len
=
len
(
tokenizer
(
sample
[
"prompt"
]).
input_ids
),
expected_output_len
=
len
(
tokenizer
(
sample
[
"expected_output"
]).
input_ids
),
request_id
=
request_id_prefix
+
str
(
i
),
))
if
len
(
samples
)
>=
num_requests
:
break
self
.
maybe_oversample_requests
(
samples
,
num_requests
)
self
.
maybe_oversample_requests
(
samples
,
num_requests
,
request_id_prefix
)
return
samples
...
...
@@ -1470,6 +2166,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
:
output_len
=
(
output_len
...
...
@@ -1477,6 +2174,7 @@ class ASRDataset(HuggingFaceDataset):
prompt
=
ASRDataset
.
TRANSCRIPTION_PREAMBLE
prompt_len
=
len
(
tokenizer
(
prompt
).
input_ids
)
sampled_requests
=
[]
ind
=
0
skipped
=
0
for
item
in
self
.
data
:
if
len
(
sampled_requests
)
>=
num_requests
:
...
...
@@ -1496,7 +2194,9 @@ class ASRDataset(HuggingFaceDataset):
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
ind
),
))
ind
+=
1
if
skipped
:
logger
.
warning
(
"%d samples discarded from dataset due to"
...
...
@@ -1504,7 +2204,8 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports."
,
skipped
,
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1541,11 +2242,13 @@ class MLPerfDataset(HuggingFaceDataset):
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
# Force dynamic output length based on reference completion.
dynamic_output
=
output_len
is
None
sampled_requests
:
list
[
SampleRequest
]
=
[]
ind
=
0
for
item
in
self
.
data
:
if
len
(
sampled_requests
)
>=
num_requests
:
...
...
@@ -1580,10 +2283,13 @@ class MLPerfDataset(HuggingFaceDataset):
prompt
=
prompt_formatted
,
prompt_len
=
prompt_len
,
expected_output_len
=
expected_output_len
,
request_id
=
request_id_prefix
+
str
(
ind
),
)
)
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
)
return
sampled_requests
...
...
@@ -1616,6 +2322,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
suffix_len
:
int
=
DEFAULT_SUFFIX_LEN
,
num_prefixes
:
int
=
DEFAULT_NUM_PREFIXES
,
output_len
:
int
=
DEFAULT_OUTPUT_LEN
,
request_id_prefix
:
str
=
""
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
vocab_size
=
tokenizer
.
vocab_size
...
...
vllm/benchmarks/lib/endpoint_request_func.py
View file @
d2b52805
...
...
@@ -9,7 +9,7 @@ import sys
import
time
import
traceback
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
aiohttp
from
tqdm.asyncio
import
tqdm
...
...
@@ -28,9 +28,10 @@ class RequestFuncInput:
model_name
:
Optional
[
str
]
=
None
logprobs
:
Optional
[
int
]
=
None
extra_body
:
Optional
[
dict
]
=
None
multi_modal_content
:
Optional
[
dict
|
list
[
dict
]]
=
None
multi_modal_content
:
Optional
[
Union
[
dict
,
list
[
dict
]]
]
=
None
ignore_eos
:
bool
=
False
language
:
Optional
[
str
]
=
None
request_id
:
Optional
[
str
]
=
None
@
dataclass
...
...
@@ -68,7 +69,7 @@ async def async_request_openai_completions(
),
"OpenAI Completions API URL must end with 'completions' or 'profile'."
payload
=
{
"model"
:
request_func_input
.
model_name
\
"model"
:
request_func_input
.
model_name
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"prompt"
:
request_func_input
.
prompt
,
"temperature"
:
0.0
,
...
...
@@ -87,6 +88,8 @@ async def async_request_openai_completions(
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
}
if
request_func_input
.
request_id
:
headers
[
"x-request-id"
]
=
request_func_input
.
request_id
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
...
...
@@ -210,6 +213,8 @@ async def async_request_openai_chat_completions(
"Content-Type"
:
"application/json"
,
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
if
request_func_input
.
request_id
:
headers
[
"x-request-id"
]
=
request_func_input
.
request_id
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
...
...
@@ -311,6 +316,8 @@ async def async_request_openai_audio(
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
if
request_func_input
.
request_id
:
headers
[
"x-request-id"
]
=
request_func_input
.
request_id
# Send audio file
def
to_bytes
(
y
,
sr
):
...
...
@@ -387,12 +394,61 @@ async def async_request_openai_audio(
return
output
async
def
async_request_openai_embeddings
(
request_func_input
:
RequestFuncInput
,
session
:
aiohttp
.
ClientSession
,
pbar
:
Optional
[
tqdm
]
=
None
,
):
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"embeddings"
),
"OpenAI Embeddings API URL must end with 'embeddings'."
headers
=
{
"Content-Type"
:
"application/json"
,
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
payload
=
{
"model"
:
request_func_input
.
model
,
"input"
:
request_func_input
.
prompt
,
}
output
=
RequestFuncOutput
()
st
=
time
.
perf_counter
()
try
:
async
with
session
.
post
(
url
=
api_url
,
headers
=
headers
,
json
=
payload
)
as
response
:
if
response
.
status
==
200
:
output
.
latency
=
time
.
perf_counter
()
-
st
data
=
await
response
.
json
()
output
.
success
=
True
output
.
generated_text
=
""
output
.
prompt_len
=
data
.
get
(
"usage"
,
{}).
get
(
"prompt_tokens"
,
0
)
else
:
output
.
success
=
False
output
.
error
=
response
.
reason
or
""
except
Exception
as
e
:
output
.
success
=
False
output
.
error
=
str
(
e
)
if
pbar
:
pbar
.
update
(
1
)
return
output
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS
=
{
"vllm"
:
async_request_openai_completions
,
"openai"
:
async_request_openai_completions
,
"openai-chat"
:
async_request_openai_chat_completions
,
"openai-audio"
:
async_request_openai_audio
,
"openai-embeddings"
:
async_request_openai_embeddings
,
}
OPENAI_COMPATIBLE_BACKENDS
=
[
...
...
vllm/benchmarks/lib/utils.py
View file @
d2b52805
...
...
@@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder):
def
clear_inf
(
self
,
o
:
Any
):
if
isinstance
(
o
,
dict
):
return
{
k
:
self
.
clear_inf
(
v
)
for
k
,
v
in
o
.
items
()}
return
{
str
(
k
)
if
not
isinstance
(
k
,
(
str
,
int
,
float
,
bool
,
type
(
None
)))
else
k
:
self
.
clear_inf
(
v
)
for
k
,
v
in
o
.
items
()
}
elif
isinstance
(
o
,
list
):
return
[
self
.
clear_inf
(
v
)
for
v
in
o
]
elif
isinstance
(
o
,
float
)
and
math
.
isinf
(
o
):
...
...
vllm/benchmarks/serve.py
View file @
d2b52805
...
...
@@ -26,6 +26,7 @@ import warnings
from
collections.abc
import
AsyncGenerator
,
Iterable
from
dataclasses
import
dataclass
from
datetime
import
datetime
from
enum
import
Enum
from
typing
import
Any
,
Literal
,
Optional
import
aiohttp
...
...
@@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
class
TaskType
(
Enum
):
GENERATION
=
"generation"
EMBEDDING
=
"embedding"
@
dataclass
class
BenchmarkMetrics
:
completed
:
int
...
...
@@ -75,6 +81,16 @@ class BenchmarkMetrics:
std_e2el_ms
:
float
percentiles_e2el_ms
:
list
[
tuple
[
float
,
float
]]
@
dataclass
class
EmbedBenchmarkMetrics
:
completed
:
int
total_input
:
int
request_throughput
:
float
total_token_throughput
:
float
mean_e2el_ms
:
float
std_e2el_ms
:
float
median_e2el_ms
:
float
percentiles_e2el_ms
:
float
def
_get_current_request_rate
(
ramp_up_strategy
:
Optional
[
Literal
[
"linear"
,
"exponential"
]],
...
...
@@ -189,6 +205,51 @@ async def get_request(
yield
request
,
request_rates
[
request_index
]
def
calculate_metrics_for_embeddings
(
outputs
:
list
[
RequestFuncOutput
],
dur_s
:
float
,
selected_percentiles
:
list
[
float
]
)
->
EmbedBenchmarkMetrics
:
"""Calculate the metrics for the embedding requests.
Args:
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
selected_percentiles: The percentiles to select.
Returns:
The calculated benchmark metrics.
"""
total_input
=
0
completed
=
0
e2els
:
list
[
float
]
=
[]
for
i
in
range
(
len
(
outputs
)):
if
outputs
[
i
].
success
:
e2els
.
append
(
outputs
[
i
].
latency
)
completed
+=
1
total_input
+=
outputs
[
i
].
prompt_len
if
completed
==
0
:
warnings
.
warn
(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments."
,
stacklevel
=
2
)
metrics
=
EmbedBenchmarkMetrics
(
completed
=
completed
,
total_input
=
total_input
,
request_throughput
=
completed
/
dur_s
,
total_token_throughput
=
total_input
/
dur_s
,
mean_e2el_ms
=
np
.
mean
(
e2els
or
0
)
*
1000
,
std_e2el_ms
=
np
.
std
(
e2els
or
0
)
*
1000
,
median_e2el_ms
=
np
.
median
(
e2els
or
0
)
*
1000
,
percentiles_e2el_ms
=
[
(
p
,
np
.
percentile
(
e2els
or
0
,
p
)
*
1000
)
for
p
in
selected_percentiles
],
)
return
metrics
def
calculate_metrics
(
input_requests
:
list
[
SampleRequest
],
outputs
:
list
[
RequestFuncOutput
],
...
...
@@ -334,7 +395,15 @@ async def benchmark(
ramp_up_end_rps
:
Optional
[
int
]
=
None
,
ready_check_timeout_sec
:
int
=
600
,
):
task_type
=
(
TaskType
.
EMBEDDING
if
api_url
.
endswith
(
"/v1/embeddings"
)
else
TaskType
.
GENERATION
)
if
endpoint_type
in
ASYNC_REQUEST_FUNCS
:
if
task_type
==
TaskType
.
EMBEDDING
:
request_func
=
ASYNC_REQUEST_FUNCS
[
"openai-embeddings"
]
else
:
request_func
=
ASYNC_REQUEST_FUNCS
[
endpoint_type
]
else
:
raise
ValueError
(
f
"Unknown endpoint_type:
{
endpoint_type
}
"
)
...
...
@@ -478,11 +547,12 @@ async def benchmark(
"timestamp"
:
timestamp
})
last_int_rps
=
current_int_rps
prompt
,
prompt_len
,
output_len
,
mm_content
=
(
prompt
,
prompt_len
,
output_len
,
mm_content
,
request_id
=
(
request
.
prompt
,
request
.
prompt_len
,
request
.
expected_output_len
,
request
.
multi_modal_data
,
request
.
request_id
,
)
req_model_id
,
req_model_name
=
model_id
,
model_name
if
lora_modules
:
...
...
@@ -498,7 +568,8 @@ async def benchmark(
logprobs
=
logprobs
,
multi_modal_content
=
mm_content
,
ignore_eos
=
ignore_eos
,
extra_body
=
extra_body
)
extra_body
=
extra_body
,
request_id
=
request_id
,)
tasks
.
append
(
asyncio
.
create_task
(
limited_request_func
(
request_func_input
=
request_func_input
,
...
...
@@ -511,6 +582,7 @@ async def benchmark(
benchmark_duration
=
time
.
perf_counter
()
-
benchmark_start_time
if
task_type
==
TaskType
.
GENERATION
:
metrics
,
actual_output_lens
=
calculate_metrics
(
input_requests
=
input_requests
,
outputs
=
outputs
,
...
...
@@ -519,6 +591,13 @@ async def benchmark(
selected_percentiles
=
selected_percentiles
,
goodput_config_dict
=
goodput_config_dict
,
)
else
:
metrics
=
calculate_metrics_for_embeddings
(
outputs
=
outputs
,
dur_s
=
benchmark_duration
,
selected_percentiles
=
selected_percentiles
,
)
actual_output_lens
=
0
print
(
"{s:{c}^{n}}"
.
format
(
s
=
' Serving Benchmark Result '
,
n
=
50
,
c
=
'='
))
print
(
"{:<40} {:<10}"
.
format
(
"Successful requests:"
,
metrics
.
completed
))
...
...
@@ -527,22 +606,28 @@ async def benchmark(
max_concurrency
))
if
request_rate
!=
float
(
'inf'
):
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request rate configured (RPS):"
,
request_rate
))
request_rate
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Benchmark duration (s):"
,
benchmark_duration
))
print
(
"{:<40} {:<10}"
.
format
(
"Total input tokens:"
,
metrics
.
total_input
))
print
(
"{:<40} {:<10}"
.
format
(
"Total generated tokens:"
,
metrics
.
total_output
))
if
isinstance
(
metrics
,
BenchmarkMetrics
):
print
(
"{:<40} {:<10}"
.
format
(
"Total generated tokens:"
,
metrics
.
total_output
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request throughput (req/s):"
,
metrics
.
request_throughput
))
if
goodput_config_dict
:
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request goodput (req/s):"
,
metrics
.
request_goodput
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Output token throughput (tok/s):"
,
metrics
.
output_throughput
))
if
isinstance
(
metrics
,
BenchmarkMetrics
):
print
(
"{:<40} {:<10.2f}"
.
format
(
"Output token throughput (tok/s):"
,
metrics
.
output_throughput
)
)
print
(
"{:<40} {:<10.2f}"
.
format
(
"Total Token throughput (tok/s):"
,
metrics
.
total_token_throughput
))
if
isinstance
(
metrics
,
BenchmarkMetrics
):
result
=
{
"duration"
:
benchmark_duration
,
"completed"
:
metrics
.
completed
,
...
...
@@ -560,6 +645,16 @@ async def benchmark(
"generated_texts"
:
[
output
.
generated_text
for
output
in
outputs
],
"errors"
:
[
output
.
error
for
output
in
outputs
],
}
else
:
result
=
{
"duration"
:
benchmark_duration
,
"completed"
:
metrics
.
completed
,
"total_input_tokens"
:
metrics
.
total_input
,
"request_throughput"
:
metrics
.
request_throughput
,
"total_token_throughput"
:
metrics
.
total_token_throughput
,
"input_lens"
:
[
output
.
prompt_len
for
output
in
outputs
],
"errors"
:
[
output
.
error
for
output
in
outputs
],
}
if
rps_change_events
:
result
[
"rps_change_events"
]
=
rps_change_events
...
...
@@ -596,9 +691,10 @@ async def benchmark(
value
))
result
[
f
"p
{
p_word
}
_
{
metric_attribute_name
}
_ms"
]
=
value
if
task_type
==
TaskType
.
GENERATION
:
process_one_metric
(
"ttft"
,
"TTFT"
,
"Time to First Token"
)
process_one_metric
(
"tpot"
,
"TPOT"
,
"Time per Output Token (excl. 1st token)"
)
process_one_metric
(
"tpot"
,
"TPOT"
,
"Time per Output Token (excl. 1st token)"
)
process_one_metric
(
"itl"
,
"ITL"
,
"Inter-token Latency"
)
process_one_metric
(
"e2el"
,
"E2EL"
,
"End-to-end Latency"
)
...
...
@@ -730,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up."
)
"if the server is not processing requests fast enough to keep up."
,
)
parser
.
add_argument
(
"--model"
,
...
...
@@ -741,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
...
...
@@ -865,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve"
,
)
parser
.
add_argument
(
"--request-id-prefix"
,
type
=
str
,
required
=
False
,
default
=
"benchmark-serving"
,
help
=
"Specify the prefix of request id."
,
)
sampling_group
=
parser
.
add_argument_group
(
"sampling parameters"
)
sampling_group
.
add_argument
(
...
...
@@ -958,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def
main
(
args
:
argparse
.
Namespace
)
->
dict
[
str
,
Any
]:
return
asyncio
.
run
(
main_async
(
args
))
async
def
main_async
(
args
:
argparse
.
Namespace
)
->
dict
[
str
,
Any
]:
print
(
args
)
random
.
seed
(
args
.
seed
)
...
...
vllm/benchmarks/throughput.py
View file @
d2b52805
...
...
@@ -435,6 +435,14 @@ def validate_args(args):
raise
ValueError
(
"Tokenizer must be the same as the model for MII backend."
)
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if
args
.
data_parallel_size
>
1
:
raise
ValueError
(
"Data parallel is not supported in offline benchmark, "
"please use benchmark serving instead"
)
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
"--backend"
,
...
...
vllm/compilation/activation_quant_fusion.py
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
(
PatternMatcherPass
,
fwd_only
,
register_replacement
)
from
torch._ops
import
OpOverload
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
def
silu_mul_pattern_static
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
torch
.
ops
.
_C
.
silu_and_mul
.
default
,
SILU_MUL_OP
=
torch
.
ops
.
_C
.
silu_and_mul
.
default
FUSED_OPS
:
dict
[
QuantKey
,
OpOverload
]
=
{
kFp8StaticTensorSym
:
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
,
# noqa: E501
}
silu_and_mul_nvfp4_quant_supported
=
(
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"silu_and_mul_nvfp4_quant"
))
if
silu_and_mul_nvfp4_quant_supported
:
FUSED_OPS
[
kNvfp4Quant
]
=
torch
.
ops
.
_C
.
silu_and_mul_nvfp4_quant
.
default
# noqa: E501
class
ActivationQuantPattern
(
ABC
):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def
__init__
(
self
,
quant_key
:
QuantKey
,
):
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
assert
self
.
quant_key
in
QUANT_OPS
,
\
f
"unsupported quantization scheme
{
self
.
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
assert
self
.
quant_key
in
FUSED_OPS
,
\
f
"unsupported fusion scheme
{
self
.
quant_key
}
"
self
.
FUSED_OP
=
FUSED_OPS
[
self
.
quant_key
]
def
empty_quant
(
self
,
*
args
,
**
kwargs
):
kwargs
=
{
'dtype'
:
self
.
quant_dtype
,
'device'
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
raise
NotImplementedError
class
SiluMulFp8StaticQuantPattern
(
ActivationQuantPattern
):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def
__init__
(
self
,
symmetric
:
bool
=
True
):
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
super
().
__init__
(
quant_key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
SILU_MUL_OP
,
result
=
result_silu_mul
,
input
=
input
)
at2
=
auto_functionalized
(
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
result
=
result
,
input
=
at1
[
1
],
scale
=
scale
)
return
at2
[
1
]
def
silu_mul_replacement_static
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
def
replacement
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at
=
auto_functionalized
(
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
,
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
input
=
input
,
scale
=
scale
)
return
at
[
1
]
inputs
=
[
self
.
empty_quant
(
5
,
4
),
# result
empty_bf16
(
5
,
4
),
# result_silu_mul
empty_bf16
(
5
,
4
),
# input
empty_fp32
(
1
,
1
)
# scale
]
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
"""
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def
empty_bf16
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
def
__init__
(
self
):
super
().
__init__
(
kNvfp4Quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
SILU_MUL_OP
,
result
=
result_silu_mul
,
input
=
input
)
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
output
=
result
,
input
=
at1
[
1
],
output_scale
=
output_scale
,
input_scale
=
scale
)
return
at2
[
1
],
at2
[
2
]
def
empty_fp8
(
*
args
,
**
kwargs
):
fp8
=
current_platform
.
fp8_dtype
()
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
fp8
,
device
=
"cuda"
)
def
replacement
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
result_block_scale
=
output_scale
,
input
=
input
,
input_global_scale
=
scale
)
return
at
[
1
],
at
[
2
]
inputs
=
[
self
.
empty_quant
(
5
,
32
),
# result
empty_i32
(
128
,
4
),
# output_scale
empty_bf16
(
5
,
64
),
# result_silu_mul
empty_bf16
(
5
,
64
),
# input
empty_fp32
(
1
,
1
)
# scale
]
def
empty_fp32
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
class
ActivationQuantFusionPass
(
VllmInductorPass
):
...
...
@@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"activation_quant_fusion_pass"
)
inputs
=
[
empty_fp8
(
5
,
4
),
# Quant output
empty_bf16
(
5
,
4
),
# Silu_and_mul output
empty_bf16
(
5
,
4
),
# Input
empty_fp32
(
1
,
1
)
# Scale
]
register_replacement
(
silu_mul_pattern_static
,
silu_mul_replacement_static
,
inputs
,
fwd_only
,
self
.
patterns
)
pattern_silu_mul_fp8
=
SiluMulFp8StaticQuantPattern
()
pattern_silu_mul_fp8
.
register
(
self
.
patterns
)
if
silu_and_mul_nvfp4_quant_supported
:
pattern_silu_mul_nvfp4
=
SiluMulNvfp4QuantPattern
()
pattern_silu_mul_nvfp4
.
register
(
self
.
patterns
)
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
self
.
begin
()
...
...
@@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass):
self
.
dump_graph
(
graph
,
"after_act_quant_fusion"
)
self
.
end_and_log
()
def
uuid
(
self
):
return
VllmInductorPass
.
hash_source
(
self
,
ActivationQuantPattern
,
SiluMulFp8StaticQuantPattern
,
SiluMulNvfp4QuantPattern
)
vllm/compilation/backends.py
View file @
d2b52805
...
...
@@ -271,7 +271,7 @@ def split_graph(graph: fx.GraphModule,
outputs
.
append
(
SplitItem
(
name
,
graph_id
,
(
graph_id
in
split_op_graphs
),
module
))
# sort by inte
t
ger graph_id, rather than string name
# sort by integer graph_id, rather than string name
outputs
.
sort
(
key
=
lambda
x
:
x
.
graph_id
)
return
split_gm
,
outputs
...
...
@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
compile_submod_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
graph_pool
,
vllm_backend
:
"VllmBackend"
):
vllm_backend
:
"VllmBackend"
):
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
self
.
fake_mode
=
detect_fake_mode
()
self
.
compile_submod_names
=
compile_submod_names
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
graph_pool
=
graph_pool
self
.
vllm_config
=
vllm_config
self
.
vllm_backend
=
vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
...
...
@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
runnable
=
piecewise_backend
,
vllm_config
=
self
.
vllm_config
,
runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
graph_pool
=
self
.
graph_pool
,
cudagraph_options
=
CUDAGraphOptions
(
debug_log_enable
=
piecewise_backend
.
is_first_graph
,
gc_disable
=
not
piecewise_backend
.
is_first_graph
,
...
...
@@ -405,7 +403,6 @@ class VllmBackend:
vllm_config
:
VllmConfig
compilation_config
:
CompilationConfig
graph_pool
:
Any
_called
:
bool
=
False
# the graph we compiled
graph
:
fx
.
GraphModule
...
...
@@ -427,19 +424,12 @@ class VllmBackend:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. la
u
nguage_model, vision_model, etc.
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self
.
prefix
=
prefix
or
model_tag
global_graph_pool
=
current_platform
.
get_global_graph_pool
()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self
.
graph_pool
=
global_graph_pool
# Passes to run on the graph post-grad.
self
.
post_grad_pass_manager
=
PostGradPassManager
()
...
...
@@ -484,7 +474,7 @@ class VllmBackend:
factors
=
[]
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect
s
the computation graph.
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash
=
envs
.
compute_hash
()
factors
.
append
(
env_hash
)
...
...
@@ -586,7 +576,7 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
.
graph_pool
,
self
.
vllm_config
,
self
).
run
(
*
example_inputs
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
...
...
vllm/compilation/base_static_graph.py
View file @
d2b52805
...
...
@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""
def
__init__
(
self
,
runnable
:
Callable
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
graph_pool
:
Any
,
**
kwargs
):
runtime_mode
:
CUDAGraphMode
,
**
kwargs
):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
...
...
@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
...
...
vllm/compilation/collective_fusion.py
View file @
d2b52805
...
...
@@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch.distributed._symmetric_memory
import
enable_symm_mem_for_group
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tp_group
,
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
...
...
@@ -18,6 +19,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class
AsyncTPPass
(
VllmInductorPass
):
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
@@ -401,6 +404,18 @@ if flashinfer_comm is not None:
6
:
MiB
//
2
,
# 512KB
8
:
MiB
//
2
,
# 512KB
}
try
:
_FI_MAX_SIZES
.
update
({
int
(
k
):
int
(
float
(
v
)
*
MiB
)
for
k
,
v
in
envs
.
VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB
.
items
()
})
except
Exception
as
e
:
raise
ValueError
(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+
str
(
e
))
from
e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE
=
MiB
//
2
...
...
@@ -465,7 +480,8 @@ if flashinfer_comm is not None:
quant_out
=
quant_out
,
scale_out
=
scale_out
,
# in vllm we only support swizzled layout
layout_code
=
flashinfer_comm
.
FP4QuantizationSFLayout
.
SWIZZLED
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
scale_factor
,
)
else
:
...
...
@@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant
=
config
.
compilation_config
.
pass_config
.
enable_fusion
)
self
.
register_patterns
()
@
enable_fake_mode
def
register_patterns
(
self
):
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
epsilon
,
...
...
vllm/compilation/cuda_graph.py
View file @
d2b52805
...
...
@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
runnable
:
Callable
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
graph_pool
:
Any
=
None
,
cudagraph_options
:
Optional
[
CUDAGraphOptions
]
=
None
):
self
.
runnable
=
runnable
self
.
vllm_config
=
vllm_config
self
.
graph_pool
=
graph_pool
self
.
runtime_mode
=
runtime_mode
self
.
compilation_config
=
vllm_config
.
compilation_config
...
...
@@ -81,7 +79,9 @@ class CUDAGraphWrapper:
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert
self
.
runtime_mode
!=
CUDAGraphMode
.
NONE
if
self
.
graph_pool
is
None
:
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self
.
graph_pool
=
current_platform
.
get_global_graph_pool
()
if
cudagraph_options
is
None
:
...
...
vllm/compilation/decorators.py
View file @
d2b52805
...
...
@@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool:
return
getattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
@
overload
def
support_torch_compile
(
*
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Callable
[[
_T
],
_T
]:
...
@
overload
def
support_torch_compile
(
*
,
...
...
@@ -69,6 +77,7 @@ def support_torch_compile(
cls
:
Optional
[
_T
]
=
None
,
*
,
dynamic_arg_dims
:
Optional
[
dict
[
str
,
Union
[
int
,
list
[
int
]]]]
=
None
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Union
[
Callable
[[
_T
],
_T
],
_T
]:
"""
A decorator to add support for compiling the forward method of a class.
...
...
@@ -118,6 +127,11 @@ def support_torch_compile(
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
...
...
@@ -149,7 +163,8 @@ def support_torch_compile(
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
,
enable_if
)
if
cls
is
not
None
:
# use `support_torch_compile` as a decorator without arguments
...
...
@@ -162,6 +177,7 @@ def support_torch_compile(
def
_support_torch_compile
(
cls
:
_T
,
dynamic_arg_dims
:
dict
[
str
,
Union
[
int
,
list
[
int
]]],
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
_T
:
"""
A decorator to add support for compiling the forward method of a class.
...
...
@@ -182,13 +198,14 @@ def _support_torch_compile(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
vllm_config
=
vllm_config
enable_compile
=
enable_if
is
None
or
enable_if
(
vllm_config
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
\
vllm_config
.
compilation_config
.
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
or
_should_ignore_torch_compile
(
self
.
__class__
)
self
.
__class__
)
or
not
enable_compile
if
self
.
do_not_compile
:
return
...
...
@@ -267,8 +284,24 @@ def _support_torch_compile(
code
.
co_filename
)
return
inline_call
(
parent
,
func
,
args
,
kwargs
)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
# vllm skip guards anyways, setting this flag to False can improve
# compile time.
dynamo_config_patches
=
{}
try
:
_
=
torch
.
_dynamo
.
config
.
enable_cpp_symbolic_shape_guards
dynamo_config_patches
[
"enable_cpp_symbolic_shape_guards"
]
=
False
except
AttributeError
:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger
.
debug
(
"enable_cpp_symbolic_shape_guards config not available"
)
with
patch
.
object
(
InliningInstructionTranslator
,
'inline_call'
,
patched_inline_call
):
patched_inline_call
),
torch
.
_dynamo
.
config
.
patch
(
**
dynamo_config_patches
):
output
=
self
.
compiled_callable
(
*
args
,
**
kwargs
)
return
output
...
...
vllm/compilation/fix_functionalization.py
View file @
d2b52805
...
...
@@ -9,6 +9,7 @@ import torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.fx_utils
import
is_func
from
.vllm_inductor_pass
import
VllmInductorPass
...
...
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if
current_platform
.
is_xpu
():
logger
.
debug
(
"XPU platform does not support fix functionalization"
"pass currently."
)
return
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_fix_functionalization"
)
...
...
@@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass):
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
# elif hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"
# ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
# mutated_args = {1: 'result', 2: 'result_block_scale'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'result_block_scale',
# 'input', 'input_global_scale'))
else
:
continue
# skip the count
...
...
vllm/compilation/fusion.py
View file @
d2b52805
...
...
@@ -12,15 +12,18 @@ from torch._ops import OpOverload
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
GroupShape
,
QuantKey
,
ScaleDesc
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
.fx_utils
import
find_getitem_maybe
from
.inductor_pass
import
enable_fake_mode
from
.multi_output_match
import
MultiOutputMatch
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
def
empty_bf16
(
*
args
,
**
kwargs
):
...
...
@@ -31,41 +34,13 @@ def empty_fp32(*args, **kwargs):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
class
QuantKey
(
NamedTuple
):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype
:
torch
.
dtype
static
:
bool
group_shape
:
GroupShape
symmetric
:
bool
=
True
def
empty_i32
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
def
__str__
(
self
):
group_shape
=
(
'per_tensor'
if
self
.
group_shape
==
GroupShape
.
PER_TENSOR
else
(
'per_token'
if
self
.
group_shape
==
GroupShape
.
PER_TOKEN
else
str
(
self
.
group_shape
)))
return
(
f
"QuantKey(
{
'static'
if
self
.
static
else
'dynamic'
}
,"
f
"
{
fx
.
graph
.
dtype_abbrs
[
self
.
dtype
]
}
,
{
group_shape
}
,"
f
"
{
'a'
if
not
self
.
symmetric
else
''
}
symmetric)"
)
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
QUANT_OPS
:
dict
[
QuantKey
,
OpOverload
]
=
{
# kFp8StaticTensorSym:
...
...
@@ -75,6 +50,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4Quant
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
# noqa: E501
class
FusedRMSQuantKey
(
NamedTuple
):
...
...
@@ -187,10 +165,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
fused_key
=
FusedRMSQuantKey
(
fused_add
=
False
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
fused_key
)
...
...
@@ -244,10 +220,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
...
...
@@ -337,10 +311,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
False
,
group_shape
=
group_shape
,
scale
=
scale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
...
...
@@ -435,10 +409,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
False
,
group_shape
=
group_shape
,
scale
=
scale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
...
...
@@ -556,6 +530,7 @@ class FusionPass(VllmInductorPass):
cls
.
_instance
.
pass_config
=
config
.
compilation_config
.
pass_config
return
cls
.
_instance
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
assert
self
.
__class__
.
_instance
is
None
,
\
"FusionPass singleton instance already exists"
...
...
vllm/compilation/fusion_attn.py
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
import
torch._inductor.pattern_matcher
as
pm
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
from
.fusion
import
QUANT_OPS
,
GroupShape
,
QuantKey
,
empty_bf16
,
empty_fp32
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
ATTN_OP
=
torch
.
ops
.
vllm
.
unified_attention_with_output
.
default
RESHAPE_OP
=
torch
.
ops
.
aten
.
reshape
.
default
class
AttentionStaticQuantPattern
:
class
AttentionQuantPattern
(
ABC
):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def
__init__
(
self
,
layer_name
:
str
,
num_heads
:
int
,
head_size
:
int
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
,
layer
:
Attention
,
quant_key
:
QuantKey
,
):
self
.
layer_name
=
layer_name
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
quant_dtype
=
quant_dtype
self
.
quant_key
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
symmetric
=
symmetric
)
self
.
layer
=
layer
self
.
layer_name
=
layer
.
layer_name
self
.
num_heads
=
layer
.
num_heads
self
.
head_size
=
layer
.
head_size
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
assert
self
.
quant_key
in
QUANT_OPS
,
\
f
"unsupported quantization scheme
{
self
.
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
...
...
@@ -48,31 +55,64 @@ class AttentionStaticQuantPattern:
kwargs
=
{
'dtype'
:
self
.
quant_dtype
,
'device'
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
,
layer
:
Attention
):
if
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_dtype
,
self
.
quant_key
.
static
,
self
.
quant_key
.
group_shape
):
@
staticmethod
def
wrap_trace_fn
(
process_fx
,
trace_fn
):
def
wrapped
(
*
args
,
**
kwargs
):
return
process_fx
(
trace_fn
(
*
args
,
**
kwargs
))
return
wrapped
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
return
gm
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
):
if
self
.
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_key
):
self
.
_register
(
pm_pass
)
@
abstractmethod
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
raise
NotImplementedError
class
AttentionFp8StaticQuantPattern
(
AttentionQuantPattern
):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def
__init__
(
self
,
layer
:
Attention
,
symmetric
:
bool
=
True
,
):
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
super
().
__init__
(
layer
,
quant_key
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
view_7
=
RESHAPE_OP
(
output_attn
,
[
-
1
,
self
.
num_heads
,
self
.
head_size
])
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
view_7
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
output_scale
=
None
,
output_block_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
])
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
result
=
output_quant
,
input
=
attn_out_view
,
...
...
@@ -82,47 +122,116 @@ class AttentionStaticQuantPattern:
def
replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
view_7
=
RESHAPE_OP
(
output_quant
,
[
-
1
,
self
.
num_heads
,
self
.
head_size
])
# attn output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
],
0.0
,
dtype
=
self
.
quant_dtype
,
device
=
q
.
device
)
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
view_7
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
scale
)
output_scale
=
scale
,
output_block_scale
=
None
)
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with
unset_fake_temporarily
(),
FakeTensorMode
():
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
*
self
.
head_size
),
# attn_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# attn_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
empty_fp32
(
1
,
1
)
# scale
]
def
wrap_trace_fn
(
process_fx
,
trace_fn
):
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
AttentionQuantPattern
.
wrap_trace_fn
(
AttentionQuantPattern
.
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
def
wrapped
(
*
args
,
**
kwargs
):
return
process_fx
(
trace_fn
(
*
args
,
**
kwargs
))
return
wrapped
class
AttentionNvfp4QuantPattern
(
AttentionQuantPattern
):
"""
Fusion for Attention+Nvfp4Quant.
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
return
gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def
__init__
(
self
,
layer
:
Attention
):
super
().
__init__
(
layer
,
kNvfp4Quant
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
])
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
output
=
output_quant
,
input
=
attn_out_view
,
output_scale
=
output_scale
,
input_scale
=
input_scale
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
def
replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
):
# attention output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
//
2
],
0.0
,
dtype
=
self
.
quant_dtype
,
device
=
q
.
device
)
# attention output block scale
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
at2
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
input_scale
,
output_block_scale
=
output_scale_view
)
output
=
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
//
2
])
return
output
,
at2
[
2
]
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# output_attn
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
//
2
),
# output_quant
empty_i32
(
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
4
)),
# output_scale
empty_fp32
(
1
,
1
),
# input_scale
]
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
wrap_trace_fn
(
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
AttentionQuantPattern
.
wrap_trace_fn
(
AttentionQuantPattern
.
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
class
AttnFusionPass
(
VllmInductorPass
):
...
...
@@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
self
.
static_fwd_ctx
=
config
.
compilation_config
.
static_forward_context
self
.
patterns
=
PatternMatcherPass
(
pass_name
=
"attn_fusion_pass"
)
for
key
,
layer
in
self
.
static_fwd_ctx
.
items
():
pattern
=
AttentionStaticQuantPattern
(
key
,
layer
.
num_heads
,
layer
.
head_size
,
current_platform
.
fp8_dtype
())
pattern
.
register_if_supported
(
self
.
patterns
,
layer
)
if
len
(
self
.
static_fwd_ctx
)
==
0
:
attn_layers
=
get_layers_from_vllm_config
(
config
,
Attention
)
for
layer_name
,
layer
in
attn_layers
.
items
():
pattern_fp8
=
AttentionFp8StaticQuantPattern
(
layer
)
pattern_fp8
.
register_if_supported
(
self
.
patterns
)
pattern_nvfp4
=
AttentionNvfp4QuantPattern
(
layer
)
pattern_nvfp4
.
register_if_supported
(
self
.
patterns
)
if
len
(
attn_layers
)
==
0
:
logger
.
warning
(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered."
)
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered."
)
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_attn_fusion"
)
count
=
self
.
patterns
.
apply
(
graph
)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph
.
eliminate_dead_code
()
logger
.
debug
(
"Fused quantization onto %s attention nodes"
,
count
)
self
.
dump_graph
(
graph
,
"after_attn_fusion"
)
self
.
end_and_log
()
def
uuid
(
self
):
return
VllmInductorPass
.
hash_source
(
self
,
AttentionStaticQuantPattern
)
return
VllmInductorPass
.
hash_source
(
self
,
AttentionQuantPattern
,
AttentionFp8StaticQuantPattern
,
AttentionNvfp4QuantPattern
)
vllm/compilation/inductor_pass.py
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
hashlib
import
inspect
import
json
...
...
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import
torch
from
torch
import
fx
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.utils
import
is_torch_equal_or_newer
...
...
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
def
enable_fake_mode
(
fn
:
Callable
[...,
Any
])
->
Callable
[...,
Any
]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@
functools
.
wraps
(
fn
)
def
fn_new
(
*
args
,
**
kwargs
)
->
Any
:
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
result
=
fn
(
*
args
,
**
kwargs
)
return
result
return
fn_new
vllm/compilation/monitor.py
View file @
d2b52805
...
...
@@ -43,7 +43,7 @@ cudagraph_capturing_enabled: bool = True
def
validate_cudagraph_capturing_enabled
():
# used to monitor whether a
n
cudagraph capturing is legal at runtime.
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global
cudagraph_capturing_enabled
...
...
vllm/compilation/pass_manager.py
View file @
d2b52805
...
...
@@ -8,13 +8,13 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
from
.activation_quant_fusion
import
ActivationQuantFusionPass
from
.fusion
import
FusionPass
from
.fusion_attn
import
AttnFusionPass
if
current_platform
.
is_cuda
():
from
.collective_fusion
import
AllReduceFusionPass
,
AsyncTPPass
from
.activation_quant_fusion
import
ActivationQuantFusionPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
,
get_pass_context
from
.noop_elimination
import
NoOpEliminationPass
...
...
vllm/compilation/sequence_parallelism.py
View file @
d2b52805
...
...
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
...
...
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
vllm/config/__init__.py
View file @
d2b52805
...
...
@@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo
)
from
vllm.config.compilation
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
PassConfig
)
from
vllm.config.parallel
import
DistributedExecutorBackend
,
ParallelConfig
from
vllm.config.parallel
import
(
DistributedExecutorBackend
,
EPLBConfig
,
ParallelConfig
)
from
vllm.config.scheduler
import
SchedulerConfig
,
SchedulerPolicy
from
vllm.config.utils
import
ConfigType
,
config
from
vllm.logger
import
init_logger
...
...
@@ -191,6 +192,16 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
yield
a
,
b
a
=
b
try
:
cls_node
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
cls
))).
body
[
0
]
except
(
OSError
,
KeyError
,
TypeError
):
# HACK: Python 3.13+ workaround - set missing __firstlineno__
# Workaround can be removed after we upgrade to pydantic==2.12.0
with
open
(
inspect
.
getfile
(
cls
))
as
f
:
for
i
,
line
in
enumerate
(
f
):
if
f
"class
{
cls
.
__name__
}
"
in
line
and
":"
in
line
:
cls
.
__firstlineno__
=
i
+
1
break
cls_node
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
cls
))).
body
[
0
]
if
not
isinstance
(
cls_node
,
ast
.
ClassDef
):
...
...
@@ -246,8 +257,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
]
MMEncoderTPMode
=
Literal
[
"weights"
,
"data"
]
class
LogprobsMode
(
enum
.
Enum
):
RAW_LOGITS
=
"raw_logits"
RAW_LOGPROBS
=
"raw_logprobs"
PROCESSED_LOGITS
=
"processed_logits"
PROCESSED_LOGPROBS
=
"processed_logprobs"
@
config
...
...
@@ -351,12 +368,13 @@ class ModelConfig:
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode
:
LogprobsMode
=
"raw_l
ogprobs
"
logprobs_mode
:
LogprobsMode
=
L
ogprobs
Mode
.
RAW_LOGPROBS
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
disable_sliding_window
:
bool
=
False
"""Whether to disable sliding window. If True, we will disable the sliding
...
...
@@ -419,7 +437,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb
:
in
t
=
4
mm_processor_cache_gb
:
floa
t
=
4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
...
...
@@ -428,6 +446,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
...
...
@@ -470,6 +501,8 @@ class ModelConfig:
logits_processors
:
Optional
[
list
[
Union
[
str
,
type
[
LogitsProcessor
]]]]
=
None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin
:
Optional
[
str
]
=
None
"""IOProcessor plugin name to load at model startup"""
enable_chunked_prefill
:
Optional
[
bool
]
=
None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
...
...
@@ -845,22 +878,25 @@ class ModelConfig:
def
_init_multimodal_config
(
self
)
->
Optional
[
"MultiModalConfig"
]:
if
self
.
_model_info
.
supports_multimodal
:
if
(
self
.
mm_encoder_tp_mode
==
"data"
and
not
self
.
_model_info
.
supports_multimodal_encoder_tp_data
):
logger
.
warning_once
(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`."
)
self
.
mm_encoder_tp_mode
=
"weights"
return
MultiModalConfig
(
limit_per_prompt
=
self
.
limit_mm_per_prompt
,
media_io_kwargs
=
self
.
media_io_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
interleave_mm_strings
=
self
.
interleave_mm_strings
,
skip_mm_profiling
=
self
.
skip_mm_profiling
)
skip_mm_profiling
=
self
.
skip_mm_profiling
,
)
return
None
def
set_mm_processor_cache_gb
(
self
,
value
:
int
)
->
None
:
mm_config
=
self
.
get_multimodal_config
()
self
.
mm_processor_cache_gb
=
value
mm_config
.
mm_processor_cache_gb
=
value
def
_get_encoder_config
(
self
):
return
get_sentence_transformer_tokenizer_config
(
self
.
model
,
self
.
revision
)
...
...
@@ -1090,9 +1126,20 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
me_quant
.
QUANTIZATION_METHODS
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"inc"
"fp8"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"inc"
,
"petit_nvfp4"
,
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
...
...
@@ -1115,7 +1162,6 @@ class ModelConfig:
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides
=
[
"marlin"
,
"bitblas"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
...
...
@@ -1125,6 +1171,7 @@ class ModelConfig:
"moe_wna16"
,
"modelopt"
,
"modelopt_fp4"
,
"petit_nvfp4"
,
]
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
@@ -1457,7 +1504,8 @@ class ModelConfig:
from
vllm.distributed.utils
import
get_pp_indices
if
(
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
or
self
.
hf_config
.
model_type
==
"mimo_mtp"
or
self
.
hf_config
.
model_type
==
"glm4_moe_mtp"
):
or
self
.
hf_config
.
model_type
==
"glm4_moe_mtp"
or
self
.
hf_config
.
model_type
==
"ernie_mtp"
):
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
else
:
...
...
@@ -1657,29 +1705,8 @@ class ModelConfig:
return
self
.
multimodal_config
is
not
None
@
property
def
processor_return_mm_hashes
(
self
)
->
bool
:
"""Whether the multi-modal processor should output hashes."""
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
False
return
mm_config
.
mm_processor_cache_gb
>
0
@
property
def
enable_mm_processor_cache
(
self
)
->
bool
:
"""Whether the multi-modal processor cache should be enabled."""
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
False
return
mm_config
.
mm_processor_cache_gb
>
0
def
get_mm_input_cache_gb
(
self
)
->
int
:
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
0
return
envs
.
VLLM_MM_INPUT_CACHE_GIB
def
is_multimodal_raw_input_only_model
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_multimodal_raw_input_only
@
property
def
is_cross_encoder
(
self
)
->
bool
:
...
...
@@ -1690,10 +1717,6 @@ class ModelConfig:
def
is_pp_supported
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_pp
@
property
def
is_multimodal_raw_input_supported
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_multimodal_raw_input
@
property
def
is_attention_free
(
self
)
->
bool
:
return
self
.
_model_info
.
is_attention_free
...
...
@@ -1904,7 +1927,8 @@ class DeviceConfig:
SpeculativeMethod
=
Literal
[
"ngram"
,
"eagle"
,
"eagle3"
,
"medusa"
,
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
]
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
,
"ernie_mtp"
]
@
config
...
...
@@ -2037,6 +2061,16 @@ class SpeculativeConfig:
"architectures"
:
[
"Glm4MoeMTPModel"
]
})
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
({
"n_predict"
:
n_predict
,
"architectures"
:
[
"ErnieMTPModel"
]
})
return
hf_config
return
hf_config
def
__post_init__
(
self
):
...
...
@@ -2055,8 +2089,8 @@ class SpeculativeConfig:
if
self
.
target_model_config
and
\
(
self
.
target_model_config
.
hf_text_config
.
model_type
\
==
"deepseek_v3"
or
self
.
target_model_config
.
hf_text_config
.
model_type
\
==
"mimo"
):
self
.
target_model_config
.
hf_text_config
.
model_type
in
(
"mimo"
,
"ernie4_5_moe"
)
):
# use the draft model from the same model:
self
.
model
=
self
.
target_model_config
.
model
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
...
...
@@ -2154,6 +2188,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes "
\
"to support multiple layers."
)
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"ernie_mtp"
):
self
.
method
=
"ernie_mtp"
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
"All Ernie MTP models only have "
\
"one layer. Might need some code changes "
\
"to support multiple layers."
)
else
:
self
.
method
=
"draft_model"
raise
NotImplementedError
(
...
...
@@ -2369,7 +2412,7 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
,
"ernie_mtp"
)
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
...
...
@@ -2401,8 +2444,8 @@ class LoRAConfig:
lora_dtype
:
Union
[
torch
.
dtype
,
LoRADType
]
=
"auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size
:
int
=
256
"""Maximum size of extra vocabulary that can be present in a
LoRA adapter
(added to the base model vocabulary)
."""
"""
(Deprecated)
Maximum size of extra vocabulary that can be present in a
LoRA adapter. Will be removed in v0.12.0
."""
lora_vocab_padding_size
:
ClassVar
[
int
]
=
current_platform
\
.
get_lora_vocab_padding_size
()
...
...
@@ -2444,6 +2487,12 @@ class LoRAConfig:
return
hash_str
def
__post_init__
(
self
):
# Deprecation warning for lora_extra_vocab_size
logger
.
warning
(
"`lora_extra_vocab_size` is deprecated and will be removed "
"in v0.12.0. Additional vocabulary support for "
"LoRA adapters is being phased out."
)
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks
=
(
8
,
16
,
32
,
64
,
128
,
256
,
320
,
512
)
...
...
@@ -2508,7 +2557,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb
:
in
t
=
4
mm_processor_cache_gb
:
floa
t
=
4
"""
The size (in GiB) of the multi-modal processor cache, which is used to
...
...
@@ -2519,6 +2568,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings
:
bool
=
False
"""
Enable fully interleaved support for multimodal prompts.
...
...
@@ -2988,7 +3053,8 @@ def get_served_model_name(model: str,
return
served_model_name
GuidedDecodingBackend
=
Literal
[
"auto"
,
"xgrammar"
,
"guidance"
,
"outlines"
]
GuidedDecodingBackend
=
Literal
[
"auto"
,
"xgrammar"
,
"guidance"
,
"outlines"
,
"lm-format-enforcer"
]
@
config
...
...
@@ -3551,7 +3617,7 @@ class VllmConfig:
if
self
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
:
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_xpu
()
:
# if cudagraph_mode is not explicitly set by users, set default
# value
if
self
.
compilation_config
.
cudagraph_mode
is
None
:
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment