Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
427 additions
and
81 deletions
+427
-81
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+1
-1
vllm/transformers_utils/__init__.py
vllm/transformers_utils/__init__.py
+2
-2
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+14
-12
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+4
-2
vllm/transformers_utils/processors/ovis.py
vllm/transformers_utils/processors/ovis.py
+4
-2
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+2
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+6
-2
vllm/utils.py
vllm/utils.py
+30
-20
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+30
-6
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+2
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+25
-16
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+3
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+5
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+13
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+25
-4
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+6
-3
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+2
-2
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+4
-4
vllm/v1/metrics/reader.py
vllm/v1/metrics/reader.py
+245
-0
vllm/v1/request.py
vllm/v1/request.py
+4
-0
No files found.
vllm/spec_decode/spec_decode_worker.py
View file @
4eabe123
...
@@ -114,7 +114,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -114,7 +114,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
return
spec_decode_worker
return
spec_decode_worker
# Reminder: Please update docs/
source/
features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
# If the feature combo become valid
class
SpecDecodeWorker
(
LoRANotSupportedWorkerBase
):
class
SpecDecodeWorker
(
LoRANotSupportedWorkerBase
):
"""Worker which implements speculative decoding.
"""Worker which implements speculative decoding.
...
...
vllm/transformers_utils/__init__.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
vllm
.envs
import
VLLM_USE_MODELSCOPE
from
vllm
import
envs
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
try
:
try
:
# Patch here, before each import happens
# Patch here, before each import happens
import
modelscope
import
modelscope
...
...
vllm/transformers_utils/config.py
View file @
4eabe123
...
@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
...
@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from
transformers.utils
import
CONFIG_NAME
as
HF_CONFIG_NAME
from
transformers.utils
import
CONFIG_NAME
as
HF_CONFIG_NAME
from
vllm
.envs
import
VLLM_USE_MODELSCOPE
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -45,13 +45,12 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
...
@@ -45,13 +45,12 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
from
modelscope
import
AutoConfig
else
:
else
:
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
MISTRAL_CONFIG_NAME
=
"params.json"
MISTRAL_CONFIG_NAME
=
"params.json"
HF_TOKEN
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -130,7 +129,7 @@ def list_repo_files(
...
@@ -130,7 +129,7 @@ def list_repo_files(
]
]
# if model is remote, use hf_hub api to list files
# if model is remote, use hf_hub api to list files
try
:
try
:
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
from
vllm.transformers_utils.utils
import
(
from
vllm.transformers_utils.utils
import
(
modelscope_list_repo_files
)
modelscope_list_repo_files
)
return
modelscope_list_repo_files
(
repo_id
,
return
modelscope_list_repo_files
(
repo_id
,
...
@@ -185,7 +184,7 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
...
@@ -185,7 +184,7 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
return
file_exists
(
str
(
model
),
return
file_exists
(
str
(
model
),
config_name
,
config_name
,
revision
=
revision
,
revision
=
revision
,
token
=
HF_TOKEN
)
token
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
)
def
patch_rope_scaling
(
config
:
PretrainedConfig
)
->
None
:
def
patch_rope_scaling
(
config
:
PretrainedConfig
)
->
None
:
...
@@ -300,7 +299,10 @@ def get_config(
...
@@ -300,7 +299,10 @@ def get_config(
" - For Hugging Face models: ensure the presence of a "
" - For Hugging Face models: ensure the presence of a "
"'config.json'.
\n
"
"'config.json'.
\n
"
" - For Mistral models: ensure the presence of a "
" - For Mistral models: ensure the presence of a "
"'params.json'.
\n
"
).
format
(
model
=
model
)
"'params.json'.
\n
"
"3. For GGUF: pass the local path of the GGUF checkpoint.
\n
"
" Loading GGUF from a remote repo directly is not yet "
"supported.
\n
"
).
format
(
model
=
model
)
raise
ValueError
(
error_message
)
from
e
raise
ValueError
(
error_message
)
from
e
...
@@ -309,7 +311,7 @@ def get_config(
...
@@ -309,7 +311,7 @@ def get_config(
model
,
model
,
revision
=
revision
,
revision
=
revision
,
code_revision
=
code_revision
,
code_revision
=
code_revision
,
token
=
HF_TOKEN
,
token
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -321,7 +323,7 @@ def get_config(
...
@@ -321,7 +323,7 @@ def get_config(
model
,
model
,
revision
=
revision
,
revision
=
revision
,
code_revision
=
code_revision
,
code_revision
=
code_revision
,
token
=
HF_TOKEN
,
token
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
,
**
kwargs
,
**
kwargs
,
)
)
else
:
else
:
...
@@ -331,7 +333,7 @@ def get_config(
...
@@ -331,7 +333,7 @@ def get_config(
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
revision
=
revision
,
code_revision
=
code_revision
,
code_revision
=
code_revision
,
token
=
HF_TOKEN
,
token
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
,
**
kwargs
,
**
kwargs
,
)
)
except
ValueError
as
e
:
except
ValueError
as
e
:
...
@@ -349,7 +351,7 @@ def get_config(
...
@@ -349,7 +351,7 @@ def get_config(
raise
e
raise
e
elif
config_format
==
ConfigFormat
.
MISTRAL
:
elif
config_format
==
ConfigFormat
.
MISTRAL
:
config
=
load_params_config
(
model
,
revision
,
token
=
HF_TOKEN
,
**
kwargs
)
config
=
load_params_config
(
model
,
revision
,
**
kwargs
)
else
:
else
:
supported_formats
=
[
supported_formats
=
[
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
...
@@ -558,7 +560,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
...
@@ -558,7 +560,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
# If model is on HuggingfaceHub, get the repo files
# If model is on HuggingfaceHub, get the repo files
repo_files
=
list_repo_files
(
model
,
repo_files
=
list_repo_files
(
model
,
revision
=
revision
,
revision
=
revision
,
token
=
HF_TOKEN
)
token
=
os
.
getenv
(
'HF_TOKEN'
,
None
)
)
except
Exception
:
except
Exception
:
repo_files
=
[]
repo_files
=
[]
...
@@ -765,7 +767,7 @@ def get_hf_image_processor_config(
...
@@ -765,7 +767,7 @@ def get_hf_image_processor_config(
**
kwargs
,
**
kwargs
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
# ModelScope does not provide an interface for image_processor
# ModelScope does not provide an interface for image_processor
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
return
dict
()
return
dict
()
# Separate model folder from file path for GGUF models
# Separate model folder from file path for GGUF models
if
check_gguf_file
(
model
):
if
check_gguf_file
(
model
):
...
...
vllm/transformers_utils/configs/eagle.py
View file @
4eabe123
...
@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
...
@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
assert
self
.
model
is
not
None
,
\
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle"
"model should not be None when method is eagle"
kwargs
[
"architectures"
]
=
[
kwargs
[
"architectures"
]
=
[
f
"Eagle
{
arch
}
"
for
arch
in
self
.
model
.
architectures
f
"Eagle
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
]
elif
method
==
"eagle3"
:
elif
method
==
"eagle3"
:
assert
self
.
model
is
not
None
,
\
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle3"
"model should not be None when method is eagle3"
kwargs
[
"architectures"
]
=
[
kwargs
[
"architectures"
]
=
[
f
"Eagle3
{
arch
}
"
for
arch
in
self
.
model
.
architectures
f
"Eagle3
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle3"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
]
else
:
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
...
...
vllm/transformers_utils/processors/ovis.py
View file @
4eabe123
...
@@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
...
@@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack
)
Unpack
)
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
from
vllm.multimodal.image
import
convert_image_mode
__all__
=
[
'OvisProcessor'
]
__all__
=
[
'OvisProcessor'
]
IGNORE_ID
=
-
100
IGNORE_ID
=
-
100
...
@@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin):
...
@@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin):
# pick the partition with maximum covering_ratio and break the tie using #sub_images
# pick the partition with maximum covering_ratio and break the tie using #sub_images
return
sorted
(
all_grids
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
][
0
]
*
x
[
0
][
1
]))[
0
][
0
]
return
sorted
(
all_grids
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
][
0
]
*
x
[
0
][
1
]))[
0
][
0
]
if
convert_to_rgb
and
image
.
mode
!=
'RGB'
:
if
convert_to_rgb
:
image
=
image
.
convert
(
'RGB'
)
image
=
convert
_image_mode
(
image
,
'RGB'
)
sides
=
self
.
get_image_size
()
sides
=
self
.
get_image_size
()
...
...
vllm/transformers_utils/tokenizer.py
View file @
4eabe123
...
@@ -13,7 +13,7 @@ import huggingface_hub
...
@@ -13,7 +13,7 @@ import huggingface_hub
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
PreTrainedTokenizerFast
)
from
vllm
.envs
import
VLLM_USE_MODELSCOPE
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
...
@@ -168,7 +168,7 @@ def get_tokenizer(
...
@@ -168,7 +168,7 @@ def get_tokenizer(
)
->
AnyTokenizer
:
)
->
AnyTokenizer
:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
"""
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
# pylint: disable=C.
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
import
re
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
,
cast
import
huggingface_hub
import
huggingface_hub
import
regex
as
re
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
...
@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
#
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
for
message
in
messages
:
for
message
in
messages
:
if
message
.
get
(
"role"
)
==
"assistant"
:
# Remove reasoning_content as unsupported by Mistral
_
=
message
.
pop
(
"reasoning_content"
,
None
)
# type: ignore
# Convert list text content to string
if
message
.
get
(
"role"
)
in
(
"assistant"
,
"tool"
):
content
=
message
.
get
(
"content"
)
content
=
message
.
get
(
"content"
)
if
isinstance
(
content
,
list
):
if
isinstance
(
content
,
list
):
content
=
"
\n
"
.
join
(
chunk
.
get
(
"text"
)
for
chunk
in
content
)
content
=
"
\n
"
.
join
(
chunk
.
get
(
"text"
)
for
chunk
in
content
)
...
...
vllm/utils.py
View file @
4eabe123
...
@@ -19,7 +19,6 @@ import json
...
@@ -19,7 +19,6 @@ import json
import
multiprocessing
import
multiprocessing
import
os
import
os
import
pickle
import
pickle
import
re
import
signal
import
signal
import
socket
import
socket
import
subprocess
import
subprocess
...
@@ -34,7 +33,8 @@ import uuid
...
@@ -34,7 +33,8 @@ import uuid
import
warnings
import
warnings
import
weakref
import
weakref
from
argparse
import
(
Action
,
ArgumentDefaultsHelpFormatter
,
ArgumentParser
,
from
argparse
import
(
Action
,
ArgumentDefaultsHelpFormatter
,
ArgumentParser
,
ArgumentTypeError
,
_ArgumentGroup
)
ArgumentTypeError
,
RawDescriptionHelpFormatter
,
_ArgumentGroup
)
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
...
@@ -54,6 +54,7 @@ import cloudpickle
...
@@ -54,6 +54,7 @@ import cloudpickle
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
psutil
import
psutil
import
regex
as
re
import
torch
import
torch
import
torch.types
import
torch.types
import
yaml
import
yaml
...
@@ -77,9 +78,15 @@ if TYPE_CHECKING:
...
@@ -77,9 +78,15 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS
=
2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
# Exception strings for non-implemented encoder/decoder scenarios
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/
source/
features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
# If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA
=
\
STR_NOT_IMPL_ENC_DEC_SWA
=
\
...
@@ -752,16 +759,15 @@ def get_kv_cache_torch_dtype(
...
@@ -752,16 +759,15 @@ def get_kv_cache_torch_dtype(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
)
->
torch
.
dtype
:
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
)
->
torch
.
dtype
:
if
isinstance
(
cache_dtype
,
str
):
if
isinstance
(
cache_dtype
,
str
):
if
cache_dtype
==
"auto"
:
if
cache_dtype
==
"auto"
:
if
isinstance
(
model_dtype
,
str
):
if
isinstance
(
model_dtype
,
str
)
and
model_dtype
in
STR_DTYPE_TO_TORCH_DTYPE
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
elif
isinstance
(
model_dtype
,
torch
.
dtype
):
elif
isinstance
(
model_dtype
,
torch
.
dtype
):
torch_dtype
=
model_dtype
torch_dtype
=
model_dtype
else
:
else
:
raise
ValueError
(
f
"Invalid model dtype:
{
model_dtype
}
"
)
raise
ValueError
(
f
"Invalid model dtype:
{
model_dtype
}
"
)
elif
cache_dtype
in
[
"half"
,
"bfloat16"
,
"float"
]
:
elif
cache_dtype
in
STR_DTYPE_TO_TORCH_DTYPE
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
elif
cache_dtype
==
"fp8"
:
torch_dtype
=
torch
.
uint8
else
:
else
:
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
elif
isinstance
(
cache_dtype
,
torch
.
dtype
):
elif
isinstance
(
cache_dtype
,
torch
.
dtype
):
...
@@ -998,7 +1004,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
...
@@ -998,7 +1004,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
"""
"""
Unlike
{class}
`itertools.groupby`, groups are not broken by
Unlike
[
`itertools.groupby`
][]
, groups are not broken by
non-contiguous data.
non-contiguous data.
"""
"""
groups
=
defaultdict
[
_K
,
list
[
_V
]](
list
)
groups
=
defaultdict
[
_K
,
list
[
_V
]](
list
)
...
@@ -1318,7 +1324,8 @@ class StoreBoolean(Action):
...
@@ -1318,7 +1324,8 @@ class StoreBoolean(Action):
"Expected 'true' or 'false'."
)
"Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
ArgumentDefaultsHelpFormatter
):
class
SortedHelpFormatter
(
ArgumentDefaultsHelpFormatter
,
RawDescriptionHelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
"""SortedHelpFormatter that sorts arguments by their option strings."""
def
_split_lines
(
self
,
text
,
width
):
def
_split_lines
(
self
,
text
,
width
):
...
@@ -1919,11 +1926,11 @@ class _PlaceholderBase:
...
@@ -1919,11 +1926,11 @@ class _PlaceholderBase:
Disallows downstream usage of placeholder modules.
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
We need to explicitly override each dunder method because
{meth}`__getattr__` is not called when they are accessed.
[`__getattr__`][vllm.utils._PlaceholderBase.__getattr__]
is not called when they are accessed.
:::{seealso}
Info:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
:::
"""
"""
def
__getattr__
(
self
,
key
:
str
)
->
Never
:
def
__getattr__
(
self
,
key
:
str
)
->
Never
:
...
@@ -2522,7 +2529,7 @@ def _maybe_force_spawn():
...
@@ -2522,7 +2529,7 @@ def _maybe_force_spawn():
logger
.
warning
(
logger
.
warning
(
"We must use the `spawn` multiprocessing start method. "
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/
getting_started
/"
"See https://docs.vllm.ai/en/latest/
usage
/"
"troubleshooting.html#python-multiprocessing "
"troubleshooting.html#python-multiprocessing "
"for more information. Reason: %s"
,
reason
)
"for more information. Reason: %s"
,
reason
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
...
@@ -2787,14 +2794,17 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
...
@@ -2787,14 +2794,17 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
# Only relevant for models using ALiBi (e.g, MPT)
# Only relevant for models using ALiBi (e.g, MPT)
def
check_use_alibi
(
model_config
:
ModelConfig
)
->
bool
:
def
check_use_alibi
(
model_config
:
ModelConfig
)
->
bool
:
return
(
getattr
(
model_config
.
hf_text_config
,
"alibi"
,
False
)
# Falcon
cfg
=
model_config
.
hf_text_config
return
(
getattr
(
cfg
,
"alibi"
,
False
)
# Falcon
or
(
"BloomForCausalLM"
in
getattr
(
model_config
.
hf_config
,
or
(
"BloomForCausalLM"
in
getattr
(
model_config
.
hf_config
,
"architectures"
,
[]))
# Bloom
"architectures"
,
[]))
# Bloom
or
getattr
(
model_config
.
hf_text_config
,
"position_encoding_type"
,
or
getattr
(
cfg
,
"position_encoding_type"
,
""
)
==
""
)
==
"alibi"
# codellm_1b_alibi
"alibi"
# codellm_1b_alibi
or
or
(
hasattr
(
cfg
,
"attn_config"
)
# MPT
(
hasattr
(
model_config
.
hf_text_config
,
"attn_config"
)
# MPT
and
((
isinstance
(
cfg
.
attn_config
,
dict
)
and
model_config
.
hf_text_config
.
attn_config
.
get
(
"alibi"
,
False
)))
and
cfg
.
attn_config
.
get
(
"alibi"
,
False
))
or
(
not
isinstance
(
cfg
.
attn_config
,
dict
)
and
getattr
(
cfg
.
attn_config
,
"alibi"
,
False
)))))
def
sha256
(
input
)
->
int
:
def
sha256
(
input
)
->
int
:
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
4eabe123
...
@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
...
@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The number of entries in the last page of each request in
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len
:
Optional
[
torch
.
Tensor
]
=
None
paged_kv_last_page_len
:
Optional
[
torch
.
Tensor
]
=
None
# The query indptr, shape : [num_decode + 1]
qo_indptr
:
Optional
[
torch
.
Tensor
]
=
None
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
...
@@ -75,27 +77,33 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -75,27 +77,33 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
seq_lens
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
page_size
=
self
.
kv_cache_spec
.
block_size
page_size
=
self
.
kv_cache_spec
.
block_size
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
device
=
self
.
runner
.
device
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
dtype
=
block_table
.
dtype
,
dtype
=
block_table
.
dtype
,
device
=
block_table
.
device
).
unsqueeze
(
0
)
device
=
device
).
unsqueeze
(
0
)
<
block_table_bounds
.
unsqueeze
(
1
))
<
block_table_bounds
.
unsqueeze
(
1
))
paged_kv_indices
=
block_table
[
mask
]
paged_kv_indices
=
block_table
[
mask
]
paged_kv_indptr
=
torch
.
cat
([
paged_kv_indptr
=
torch
.
cat
([
torch
.
zeros
(
1
,
torch
.
zeros
(
1
,
dtype
=
block_table_bounds
.
dtype
,
device
=
device
),
dtype
=
block_table_bounds
.
dtype
,
device
=
block_table_bounds
.
device
),
block_table_bounds
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
block_table_bounds
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
])
])
paged_kv_last_page_len
=
seq_lens
%
page_size
paged_kv_last_page_len
=
seq_lens
%
page_size
paged_kv_last_page_len
=
torch
.
where
(
paged_kv_last_page_len
==
0
,
paged_kv_last_page_len
=
torch
.
where
(
paged_kv_last_page_len
==
0
,
page_size
,
paged_kv_last_page_len
)
page_size
,
paged_kv_last_page_len
)
qo_indptr
=
torch
.
arange
(
0
,
self
.
_num_decodes
+
1
,
step
=
1
,
dtype
=
torch
.
int32
,
device
=
device
)
return
(
return
(
paged_kv_indices
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_indptr
,
paged_kv_last_page_len
,
paged_kv_last_page_len
,
qo_indptr
,
)
)
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
...
@@ -105,6 +113,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -105,6 +113,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_indices
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_indptr
,
paged_last_page_len
,
paged_last_page_len
,
qo_indptr
,
)
=
self
.
_get_paged_kv_tensors
(
block_table_tensor
,
seq_lens
)
)
=
self
.
_get_paged_kv_tensors
(
block_table_tensor
,
seq_lens
)
attn_metadata
=
AiterMLADecodeMetadata
(
attn_metadata
=
AiterMLADecodeMetadata
(
...
@@ -112,7 +121,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -112,7 +121,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_last_page_len
)
paged_kv_last_page_len
=
paged_last_page_len
,
qo_indptr
=
qo_indptr
)
return
attn_metadata
return
attn_metadata
...
@@ -137,7 +147,10 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -137,7 +147,10 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
**
mla_args
)
assert
(
num_heads
==
16
or
num_heads
==
128
),
(
f
"Aiter MLA only supports 16 or 128 number of heads.
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
"Try adjusting tensor_parallel_size value."
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
]
...
@@ -189,7 +202,18 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -189,7 +202,18 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
if
self
.
num_heads
==
16
:
# AITER MLA decode kernel only supports
# max_seqlen_q=1 when using 16 heads.
max_seqlen_qo
=
1
else
:
# AITER MLA decode Kernel handles arbitrary
# max_seqlen_q values when using 128 heads.
assert
attn_metadata
.
prefill
is
not
None
max_seqlen_qo
=
attn_metadata
.
prefill
.
max_query_len
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
decode
.
qo_indptr
,
max_seqlen_qo
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_last_page_len
)
attn_metadata
.
decode
.
paged_kv_last_page_len
)
...
...
vllm/v1/core/kv_cache_manager.py
View file @
4eabe123
...
@@ -174,6 +174,7 @@ class KVCacheManager:
...
@@ -174,6 +174,7 @@ class KVCacheManager:
num_new_tokens
:
int
,
num_new_tokens
:
int
,
num_new_computed_tokens
:
int
=
0
,
num_new_computed_tokens
:
int
=
0
,
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
num_draft_tokens
:
int
=
0
,
num_lookahead_tokens
:
int
=
0
,
num_lookahead_tokens
:
int
=
0
,
delay_cache_blocks
:
bool
=
False
,
delay_cache_blocks
:
bool
=
False
,
)
->
Optional
[
KVCacheBlocks
]:
)
->
Optional
[
KVCacheBlocks
]:
...
@@ -273,7 +274,7 @@ class KVCacheManager:
...
@@ -273,7 +274,7 @@ class KVCacheManager:
# generated (accepted) tokens.
# generated (accepted) tokens.
self
.
single_type_manager
.
cache_blocks
(
self
.
single_type_manager
.
cache_blocks
(
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
+
num_new_tokens
-
len
(
request
.
spec
_token
_id
s
)
)
num_computed_tokens
+
num_new_tokens
-
num_draft
_tokens
)
return
KVCacheBlocks
(
new_blocks
)
return
KVCacheBlocks
(
new_blocks
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
4eabe123
...
@@ -227,10 +227,15 @@ class Scheduler(SchedulerInterface):
...
@@ -227,10 +227,15 @@ class Scheduler(SchedulerInterface):
req_index
+=
1
req_index
+=
1
continue
continue
num_draft_tokens
=
max
(
num_new_tokens
+
request
.
num_computed_tokens
-
request
.
num_tokens
,
0
)
while
True
:
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_new_tokens
,
num_new_tokens
,
num_draft_tokens
=
num_draft_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
if
new_blocks
is
None
:
if
new_blocks
is
None
:
# The request cannot be scheduled.
# The request cannot be scheduled.
...
@@ -310,15 +315,16 @@ class Scheduler(SchedulerInterface):
...
@@ -310,15 +315,16 @@ class Scheduler(SchedulerInterface):
break
break
request
=
self
.
waiting
[
0
]
request
=
self
.
waiting
[
0
]
num_prealloc_computed_tokens
=
0
#
P/D
: skip request if still waiting for remote kvs.
#
KVTransfer
: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
if
is_ready
:
if
is_ready
:
request
.
status
=
RequestStatus
.
WAITING
request
.
status
=
RequestStatus
.
WAITING
num_prealloc_computed_tokens
=
(
request
.
num_computed_tokens
)
else
:
else
:
logger
.
debug
(
"%s is still in WAITING_FOR_REMOTE_KVS state."
,
request
.
request_id
)
self
.
waiting
.
popleft
()
self
.
waiting
.
popleft
()
skipped_waiting_requests
.
appendleft
(
request
)
skipped_waiting_requests
.
appendleft
(
request
)
continue
continue
...
@@ -349,8 +355,9 @@ class Scheduler(SchedulerInterface):
...
@@ -349,8 +355,9 @@ class Scheduler(SchedulerInterface):
load_kv_async
=
False
load_kv_async
=
False
# Get already-cached tokens.
# Get already-cached tokens.
if
num_prealloc_computed_tokens
==
0
:
if
request
.
num_computed_tokens
==
0
:
new_computed_blocks
,
num_native_computed_tokens
=
\
# Get locally-cached tokens.
new_computed_blocks
,
num_new_local_computed_tokens
=
\
self
.
kv_cache_manager
.
get_computed_blocks
(
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
request
)
...
@@ -358,23 +365,22 @@ class Scheduler(SchedulerInterface):
...
@@ -358,23 +365,22 @@ class Scheduler(SchedulerInterface):
if
self
.
connector
is
not
None
:
if
self
.
connector
is
not
None
:
num_external_computed_tokens
,
load_kv_async
=
(
num_external_computed_tokens
,
load_kv_async
=
(
self
.
connector
.
get_num_new_matched_tokens
(
self
.
connector
.
get_num_new_matched_tokens
(
request
,
num_n
ative
_computed_tokens
))
request
,
num_n
ew_local
_computed_tokens
))
# Total computed tokens (local + external).
# Total computed tokens (local + external).
num_computed_tokens
=
(
num_n
ative
_computed_tokens
+
num_computed_tokens
=
(
num_n
ew_local
_computed_tokens
+
num_external_computed_tokens
)
num_external_computed_tokens
)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else
:
else
:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks
=
KVCacheBlocks
.
create_empty
()
new_computed_blocks
=
KVCacheBlocks
.
create_empty
()
num_native_computed_tokens
=
0
num_new_local_computed_tokens
=
0
num_computed_tokens
=
request
.
num_computed_tokens
# Total computed tokens (allocated in prior step).
num_computed_tokens
=
num_prealloc_computed_tokens
encoder_inputs_to_schedule
=
None
encoder_inputs_to_schedule
=
None
new_encoder_budget
=
encoder_budget
new_encoder_budget
=
encoder_budget
#
P/D
: loading remote KV, do not allocate for new work.
#
KVTransfer
: loading remote KV, do not allocate for new work.
if
load_kv_async
:
if
load_kv_async
:
assert
num_external_computed_tokens
>
0
assert
num_external_computed_tokens
>
0
num_new_tokens
=
0
num_new_tokens
=
0
...
@@ -405,7 +411,7 @@ class Scheduler(SchedulerInterface):
...
@@ -405,7 +411,7 @@ class Scheduler(SchedulerInterface):
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_new_tokens
+
num_external_computed_tokens
,
num_new_tokens
+
num_external_computed_tokens
,
num_n
ative
_computed_tokens
,
num_n
ew_local
_computed_tokens
,
new_computed_blocks
,
new_computed_blocks
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
delay_cache_blocks
=
load_kv_async
,
delay_cache_blocks
=
load_kv_async
,
...
@@ -457,7 +463,9 @@ class Scheduler(SchedulerInterface):
...
@@ -457,7 +463,9 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
# Count the number of prefix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
# Encoder-related.
# Encoder-related.
if
encoder_inputs_to_schedule
:
if
encoder_inputs_to_schedule
:
scheduled_encoder_inputs
[
request
.
request_id
]
=
(
scheduled_encoder_inputs
[
request
.
request_id
]
=
(
...
@@ -799,6 +807,7 @@ class Scheduler(SchedulerInterface):
...
@@ -799,6 +807,7 @@ class Scheduler(SchedulerInterface):
stop_reason
=
request
.
stop_reason
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
))
else
:
else
:
...
...
vllm/v1/engine/__init__.py
View file @
4eabe123
...
@@ -107,6 +107,9 @@ class EngineCoreOutput(
...
@@ -107,6 +107,9 @@ class EngineCoreOutput(
events
:
Optional
[
list
[
EngineCoreEvent
]]
=
None
events
:
Optional
[
list
[
EngineCoreEvent
]]
=
None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
# The number of tokens with prefix cache hits.
num_cached_tokens
:
int
=
0
@
property
@
property
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
return
self
.
finish_reason
is
not
None
return
self
.
finish_reason
is
not
None
...
...
vllm/v1/engine/async_llm.py
View file @
4eabe123
...
@@ -20,6 +20,8 @@ from vllm.outputs import RequestOutput
...
@@ -20,6 +20,8 @@ from vllm.outputs import RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -80,6 +82,9 @@ class AsyncLLM(EngineClient):
...
@@ -80,6 +82,9 @@ class AsyncLLM(EngineClient):
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github."
)
"VLLM_USE_V1=0 or 1 and report this issue on Github."
)
# Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value
()
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
...
...
vllm/v1/engine/core.py
View file @
4eabe123
...
@@ -57,6 +57,10 @@ class EngineCore:
...
@@ -57,6 +57,10 @@ class EngineCore:
executor_fail_callback
:
Optional
[
Callable
]
=
None
):
executor_fail_callback
:
Optional
[
Callable
]
=
None
):
assert
vllm_config
.
model_config
.
runner_type
!=
"pooling"
assert
vllm_config
.
model_config
.
runner_type
!=
"pooling"
# plugins need to be loaded at the engine/scheduler level too
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
logger
.
info
(
"Initializing a V1 LLM engine (v%s) with config: %s"
,
logger
.
info
(
"Initializing a V1 LLM engine (v%s) with config: %s"
,
VLLM_VERSION
,
vllm_config
)
VLLM_VERSION
,
vllm_config
)
...
@@ -336,6 +340,13 @@ class EngineCore:
...
@@ -336,6 +340,13 @@ class EngineCore:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
kwargs
)
def
save_tensorized_model
(
self
,
tensorizer_config
,
)
->
None
:
self
.
model_executor
.
save_tensorized_model
(
tensorizer_config
=
tensorizer_config
,
)
class
EngineCoreProc
(
EngineCore
):
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
"""ZMQ-wrapper for running EngineCore in background process."""
...
@@ -706,7 +717,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -706,7 +717,7 @@ class DPEngineCoreProc(EngineCoreProc):
for
i
in
range
(
local_dp_rank
*
world_size
,
(
local_dp_rank
+
1
)
*
for
i
in
range
(
local_dp_rank
*
world_size
,
(
local_dp_rank
+
1
)
*
world_size
))
world_size
))
self
.
local_
dp_rank
=
local_
dp_rank
self
.
dp_rank
=
dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
self
.
current_wave
=
0
...
@@ -779,7 +790,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -779,7 +790,7 @@ class DPEngineCoreProc(EngineCoreProc):
local_unfinished_reqs
)
local_unfinished_reqs
)
if
not
self
.
engines_running
:
if
not
self
.
engines_running
:
if
self
.
local_
dp_rank
==
0
:
if
self
.
dp_rank
==
0
:
# Notify client that we are pausing the loop.
# Notify client that we are pausing the loop.
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
self
.
current_wave
)
self
.
current_wave
)
...
...
vllm/v1/engine/llm_engine.py
View file @
4eabe123
...
@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor
...
@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.loggers
import
StatLoggerFactory
from
vllm.v1.metrics.loggers
import
(
PrometheusStatLogger
,
StatLoggerBase
,
StatLoggerFactory
)
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.stats
import
IterationStats
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -64,6 +67,11 @@ class LLMEngine:
...
@@ -64,6 +67,11 @@ class LLMEngine:
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
log_stats
=
log_stats
self
.
stat_logger
:
Optional
[
StatLoggerBase
]
=
None
if
self
.
log_stats
:
self
.
stat_logger
=
PrometheusStatLogger
(
vllm_config
)
# important: init dp group before init the engine_core
# important: init dp group before init the engine_core
# In the decoupled engine case this is handled in EngineCoreProc.
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
...
@@ -86,7 +94,7 @@ class LLMEngine:
...
@@ -86,7 +94,7 @@ class LLMEngine:
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self
.
output_processor
=
OutputProcessor
(
self
.
tokenizer
,
self
.
output_processor
=
OutputProcessor
(
self
.
tokenizer
,
log_stats
=
False
)
log_stats
=
self
.
log_stats
)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self
.
engine_core
=
EngineCoreClient
.
make_client
(
self
.
engine_core
=
EngineCoreClient
.
make_client
(
...
@@ -94,7 +102,7 @@ class LLMEngine:
...
@@ -94,7 +102,7 @@ class LLMEngine:
asyncio_mode
=
False
,
asyncio_mode
=
False
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
False
,
# FIXME: implement
log_stats
=
self
.
log_stats
,
)
)
if
not
multiprocess_mode
:
if
not
multiprocess_mode
:
...
@@ -223,12 +231,21 @@ class LLMEngine:
...
@@ -223,12 +231,21 @@ class LLMEngine:
outputs
=
self
.
engine_core
.
get_output
()
outputs
=
self
.
engine_core
.
get_output
()
# 2) Process EngineCoreOutputs.
# 2) Process EngineCoreOutputs.
iteration_stats
=
IterationStats
()
if
self
.
log_stats
else
None
processed_outputs
=
self
.
output_processor
.
process_outputs
(
processed_outputs
=
self
.
output_processor
.
process_outputs
(
outputs
.
outputs
)
outputs
.
outputs
,
engine_core_timestamp
=
outputs
.
timestamp
,
iteration_stats
=
iteration_stats
)
# 3) Abort any reqs that finished due to stop strings.
# 3) Abort any reqs that finished due to stop strings.
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
# 4) Record stats
if
self
.
stat_logger
is
not
None
:
assert
outputs
.
scheduler_stats
is
not
None
self
.
stat_logger
.
record
(
scheduler_stats
=
outputs
.
scheduler_stats
,
iteration_stats
=
iteration_stats
)
return
processed_outputs
.
request_outputs
return
processed_outputs
.
request_outputs
def
get_vllm_config
(
self
):
def
get_vllm_config
(
self
):
...
@@ -260,6 +277,10 @@ class LLMEngine:
...
@@ -260,6 +277,10 @@ class LLMEngine:
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
return
self
.
engine_core
.
is_sleeping
()
return
self
.
engine_core
.
is_sleeping
()
def
get_metrics
(
self
)
->
list
[
Metric
]:
assert
self
.
log_stats
,
"Stat logging disabled"
return
get_metrics_snapshot
()
def
get_tokenizer_group
(
self
)
->
TokenizerGroup
:
def
get_tokenizer_group
(
self
)
->
TokenizerGroup
:
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"Unable to get tokenizer because "
raise
ValueError
(
"Unable to get tokenizer because "
...
...
vllm/v1/engine/output_processor.py
View file @
4eabe123
...
@@ -147,6 +147,7 @@ class RequestState:
...
@@ -147,6 +147,7 @@ class RequestState:
finish_reason
:
Optional
[
FinishReason
],
finish_reason
:
Optional
[
FinishReason
],
stop_reason
:
Union
[
int
,
str
,
None
],
stop_reason
:
Union
[
int
,
str
,
None
],
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
num_cached_tokens
:
int
=
0
,
)
->
Optional
[
RequestOutput
]:
)
->
Optional
[
RequestOutput
]:
finished
=
finish_reason
is
not
None
finished
=
finish_reason
is
not
None
...
@@ -169,7 +170,7 @@ class RequestState:
...
@@ -169,7 +170,7 @@ class RequestState:
return
None
return
None
return
self
.
_new_request_output
(
request_id
,
outputs
,
finished
,
return
self
.
_new_request_output
(
request_id
,
outputs
,
finished
,
kv_transfer_params
)
kv_transfer_params
,
num_cached_tokens
)
def
_new_request_output
(
def
_new_request_output
(
self
,
self
,
...
@@ -177,6 +178,7 @@ class RequestState:
...
@@ -177,6 +178,7 @@ class RequestState:
outputs
:
list
[
CompletionOutput
],
outputs
:
list
[
CompletionOutput
],
finished
:
bool
,
finished
:
bool
,
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
num_cached_tokens
:
int
=
0
,
)
->
RequestOutput
:
)
->
RequestOutput
:
if
self
.
output_kind
==
RequestOutputKind
.
DELTA
:
if
self
.
output_kind
==
RequestOutputKind
.
DELTA
:
...
@@ -193,6 +195,7 @@ class RequestState:
...
@@ -193,6 +195,7 @@ class RequestState:
outputs
=
outputs
,
outputs
=
outputs
,
finished
=
finished
,
finished
=
finished
,
kv_transfer_params
=
kv_transfer_params
,
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
num_cached_tokens
,
)
)
def
_new_completion_output
(
def
_new_completion_output
(
...
@@ -340,7 +343,7 @@ class OutputProcessor:
...
@@ -340,7 +343,7 @@ class OutputProcessor:
finish_reason
=
engine_core_output
.
finish_reason
finish_reason
=
engine_core_output
.
finish_reason
stop_reason
=
engine_core_output
.
stop_reason
stop_reason
=
engine_core_output
.
stop_reason
kv_transfer_params
=
engine_core_output
.
kv_transfer_params
kv_transfer_params
=
engine_core_output
.
kv_transfer_params
num_cached_tokens
=
engine_core_output
.
num_cached_tokens
req_state
.
is_prefilling
=
False
req_state
.
is_prefilling
=
False
# 2) Detokenize the token ids into text and perform stop checks.
# 2) Detokenize the token ids into text and perform stop checks.
...
@@ -356,7 +359,7 @@ class OutputProcessor:
...
@@ -356,7 +359,7 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
# 4) Create and handle RequestOutput objects.
if
request_output
:
=
req_state
.
make_request_output
(
if
request_output
:
=
req_state
.
make_request_output
(
new_token_ids
,
finish_reason
,
stop_reason
,
new_token_ids
,
finish_reason
,
stop_reason
,
kv_transfer_params
):
kv_transfer_params
,
num_cached_tokens
):
if
req_state
.
queue
is
not
None
:
if
req_state
.
queue
is
not
None
:
# AsyncLLM: put into queue for handling by generate().
# AsyncLLM: put into queue for handling by generate().
req_state
.
queue
.
put
(
request_output
)
req_state
.
queue
.
put
(
request_output
)
...
...
vllm/v1/executor/multiproc_executor.py
View file @
4eabe123
...
@@ -38,7 +38,7 @@ logger = init_logger(__name__)
...
@@ -38,7 +38,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
EXECUTE_MODEL_TIMEOUT_S
=
4
0
EXECUTE_MODEL_TIMEOUT_S
=
30
0
class
MultiprocExecutor
(
Executor
):
class
MultiprocExecutor
(
Executor
):
...
@@ -50,6 +50,7 @@ class MultiprocExecutor(Executor):
...
@@ -50,6 +50,7 @@ class MultiprocExecutor(Executor):
self
.
is_failed
=
False
self
.
is_failed
=
False
self
.
shutdown_event
=
threading
.
Event
()
self
.
shutdown_event
=
threading
.
Event
()
self
.
failure_callback
:
Optional
[
FailureCallback
]
=
None
self
.
failure_callback
:
Optional
[
FailureCallback
]
=
None
self
.
io_thread_pool
:
Optional
[
ThreadPoolExecutor
]
=
None
self
.
world_size
=
self
.
parallel_config
.
world_size
self
.
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
...
@@ -107,7 +108,6 @@ class MultiprocExecutor(Executor):
...
@@ -107,7 +108,6 @@ class MultiprocExecutor(Executor):
# For pipeline parallel, we use a thread pool for asynchronous
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
# execute_model.
self
.
io_thread_pool
:
Optional
[
ThreadPoolExecutor
]
=
None
if
self
.
max_concurrent_batches
>
1
:
if
self
.
max_concurrent_batches
>
1
:
# Note: must use only 1 IO thread to keep dequeue sequence
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# from the response queue
...
...
vllm/v1/metrics/loggers.py
View file @
4eabe123
...
@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase):
# Counters
# Counters
#
#
self
.
counter_num_preempted_reqs
=
self
.
_counter_cls
(
self
.
counter_num_preempted_reqs
=
self
.
_counter_cls
(
name
=
"vllm:num_preemptions
_total
"
,
name
=
"vllm:num_preemptions"
,
documentation
=
"Cumulative number of preemption from the engine."
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_prompt_tokens
=
self
.
_counter_cls
(
self
.
counter_prompt_tokens
=
self
.
_counter_cls
(
name
=
"vllm:prompt_tokens
_total
"
,
name
=
"vllm:prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_generation_tokens
=
self
.
_counter_cls
(
self
.
counter_generation_tokens
=
self
.
_counter_cls
(
name
=
"vllm:generation_tokens
_total
"
,
name
=
"vllm:generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_request_success
:
dict
[
FinishReason
,
self
.
counter_request_success
:
dict
[
FinishReason
,
prometheus_client
.
Counter
]
=
{}
prometheus_client
.
Counter
]
=
{}
counter_request_success_base
=
self
.
_counter_cls
(
counter_request_success_base
=
self
.
_counter_cls
(
name
=
"vllm:request_success
_total
"
,
name
=
"vllm:request_success"
,
documentation
=
"Count of successfully processed requests."
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
"finished_reason"
])
labelnames
=
labelnames
+
[
"finished_reason"
])
for
reason
in
FinishReason
:
for
reason
in
FinishReason
:
...
...
vllm/v1/metrics/reader.py
0 → 100644
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
Optional
from
prometheus_client
import
REGISTRY
from
prometheus_client
import
Metric
as
PromMetric
from
prometheus_client.samples
import
Sample
@dataclass
class Metric:
    """Common base for the snapshot of a single Prometheus metric.

    A metric is identified by its name together with a set of
    key=value labels; one vLLM instance may expose several metrics
    that share a name and differ only in their labels.
    """

    # Name of the metric.
    name: str
    # Label key/value pairs attached to this particular metric.
    labels: dict[str, str]
@dataclass
class Counter(Metric):
    """An integer counter whose value only ever increases."""

    # Current reading of the counter.
    value: int
@dataclass
class Vector(Metric):
    """An ordered list of integer counters.

    Prometheus itself has no such type. It exists to model exactly
    one metric: vllm:spec_decode_num_accepted_tokens_per_pos.
    """

    # One counter reading per index, in order.
    values: list[int]
@dataclass
class Gauge(Metric):
    """A numeric reading that may rise or fall over time."""

    # Current reading of the gauge.
    value: float
@dataclass
class Histogram(Metric):
    """Observations sorted into configurable buckets.

    ``buckets`` is a dictionary keyed by each bucket's upper limit,
    whose value is the count observed in that bucket. A '+Inf' key
    is always present.

    ``count`` is the total count across all buckets and is identical
    to the count of the '+Inf' bucket.

    ``sum`` is the total of all observed values.
    """

    # Total number of observations.
    count: int
    # Sum of all observed values.
    sum: float
    # Upper limit of each bucket -> observed count in that bucket.
    buckets: dict[str, int]
def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Walks the Prometheus ``REGISTRY``, skips every metric whose name
    does not start with ``vllm:`` and converts the remaining ones:

    * a gauge yields one ``Gauge`` per sample named like the metric;
    * a counter yields one ``Counter`` per ``_total`` sample, except
      ``vllm:spec_decode_num_accepted_tokens_per_pos``, which is folded
      into one ``Vector`` per label set;
    * a histogram yields one ``Histogram`` per label set, built from
      its ``_bucket``, ``_count`` and ``_sum`` samples.

    Returns:
        The converted metrics, in the order the registry yields them.

    Raises:
        AssertionError: If a ``vllm:`` metric has a type other than
            gauge, counter or histogram.

    Example:
        >>> for metric in llm.get_metrics():
        ...     if isinstance(metric, Counter):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Gauge):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Histogram):
        ...         print(f"{metric}")
        ...         print(f"    sum = {metric.sum}")
        ...         print(f"    count = {metric.count}")
        ...         for bucket_le, value in metric.buckets.items():
        ...             print(f"    {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            collected.extend(
                Gauge(name=metric.name, labels=s.labels, value=s.value)
                for s in _get_samples(metric))
        elif metric.type == "counter":
            # Only the "<name>_total" samples carry the counter value.
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:spec_decode_num_accepted_tokens_per_pos
                # special case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                collected.extend(
                    Vector(name=metric.name, labels=labels, values=values)
                    for labels, values in
                    _digest_num_accepted_by_pos_samples(samples))
            else:
                collected.extend(
                    Counter(name=metric.name,
                            labels=s.labels,
                            value=int(s.value)) for s in samples)
        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            collected.extend(
                Histogram(name=metric.name,
                          labels=labels,
                          buckets=buckets,
                          count=count_value,
                          sum=sum_value)
                for labels, buckets, count_value, sum_value in
                _digest_histogram(bucket_samples, count_samples,
                                  sum_samples))
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")

    return collected
def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    """Pick the samples of ``metric`` that carry one exact sample name.

    The wanted sample name is ``metric.name`` itself, or ``metric.name``
    with ``suffix`` appended when a suffix is given. Samples with any
    other name are left out.
    """
    wanted = metric.name
    if suffix is not None:
        wanted = wanted + suffix
    return [sample for sample in metric.samples if sample.name == wanted]
def _strip_label(labels: dict[str, str],
                 key_to_remove: str) -> dict[str, str]:
    """Return a copy of ``labels`` with ``key_to_remove`` dropped.

    The dictionary passed in is not modified.

    Raises:
        KeyError: If ``key_to_remove`` is not a key of ``labels``.
    """
    remaining = dict(labels)
    del remaining[key_to_remove]
    return remaining
def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    """Regroup flat histogram samples into one record per label set.

    Args:
        bucket_samples: Bucket samples; each carries an 'le' label
            plus the labels identifying the histogram instance.
        count_samples: One total-count sample per label set.
        sum_samples: One sum sample per label set.

    Returns:
        One ``(labels, buckets, count, sum)`` tuple per distinct label
        set (ignoring 'le'), in the order the label sets first appear
        in ``bucket_samples``. ``buckets`` maps each 'le' value to the
        integer count of that bucket.

    Raises:
        KeyError: If a bucket sample has no 'le' label.
        AssertionError: If the three sample lists do not cover exactly
            the same label sets.
    """
    #
    # In the case of DP, we have an indigestible
    # per-bucket-per-engine count as a list of labelled
    # samples, along with total and sum samples
    #
    # bucket_samples (in):
    #   labels = {le: 100, idx: 0}, value = 2
    #   labels = {le: 200, idx: 0}, value = 4
    #   labels = {le: Inf, idx: 0}, value = 10
    #   labels = {le: 100, idx: 1}, value = 1
    #   labels = {le: 200, idx: 1}, value = 5
    #   labels = {le: Inf, idx: 1}, value = 7
    # count_samples (in):
    #   labels = {idx: 0}, value = 10
    #   labels = {idx: 1}, value = 7
    # sum_samples (in):
    #   labels = {idx: 0}, value = 2000
    #   labels = {idx: 1}, value = 1200
    #
    # output: [
    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
    # ]
    #
    # Label dicts are not hashable, so each label set is keyed by a
    # frozenset of its (key, value) pairs.
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for s in bucket_samples:
        bucket = s.labels["le"]
        labels_key = frozenset(_strip_label(s.labels, "le").items())
        buckets_by_labels.setdefault(labels_key, {})[bucket] = int(s.value)

    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {
        frozenset(s.labels.items()): int(s.value)
        for s in count_samples
    }

    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {
        frozenset(s.labels.items()): s.value
        for s in sum_samples
    }

    # Internal invariant: every label set must have buckets, a count
    # and a sum. Dict key views compare like sets.
    assert (buckets_by_labels.keys() == counts_by_labels.keys() ==
            sums_by_labels.keys())

    return [(dict(k), buckets, counts_by_labels[k], sums_by_labels[k])
            for k, buckets in buckets_by_labels.items()]
def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    """Fold per-position counter samples into one vector per label set.

    Args:
        samples: Counter samples, each carrying a 'position' label
            (parsed with ``int``) plus the labels identifying the
            counter instance.

    Returns:
        One ``(labels, values)`` tuple per distinct label set (ignoring
        'position'), in first-seen order. ``values[p]`` is the integer
        sample value for position ``p``. Every vector has the same
        length - the largest position seen across *all* label sets
        plus one - and positions with no sample are reported as 0.
        An empty ``samples`` list yields an empty result.

    Raises:
        KeyError: If a sample has no 'position' label.
    """
    #
    # In the case of DP, we have an indigestible
    # per-position-per-engine count as a list of
    # labelled samples
    #
    # samples (in):
    #   labels = {pos: 0, idx: 0}, value = 10
    #   labels = {pos: 1, idx: 0}, value = 7
    #   labels = {pos: 2, idx: 0}, value = 2
    #   labels = {pos: 0, idx: 1}, value = 5
    #   labels = {pos: 1, idx: 1}, value = 3
    #   labels = {pos: 2, idx: 1}, value = 1
    #
    # output: [
    #   {idx: 0}, [10, 7, 2]
    #   {idx: 1}, [5, 3, 1]
    # ]
    #
    max_pos = 0
    # Label dicts are not hashable, so each label set is keyed by a
    # frozenset of its (key, value) pairs.
    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}

    for s in samples:
        position = int(s.labels["position"])
        max_pos = max(max_pos, position)

        labels_key = frozenset(_strip_label(s.labels, "position").items())
        values_by_labels.setdefault(labels_key, {})[position] = int(s.value)

    output = []
    for labels_key, values_by_position in values_by_labels.items():
        # Zero-fill so positions without a sample still get a slot.
        values = [0] * (max_pos + 1)
        for pos, val in values_by_position.items():
            values[pos] = val
        output.append((dict(labels_key), values))
    return output
vllm/v1/request.py
View file @
4eabe123
...
@@ -77,6 +77,10 @@ class Request:
...
@@ -77,6 +77,10 @@ class Request:
self
.
output_token_ids
=
ConstantList
(
self
.
_output_token_ids
)
self
.
output_token_ids
=
ConstantList
(
self
.
_output_token_ids
)
self
.
all_token_ids
=
ConstantList
(
self
.
_all_token_ids
)
self
.
all_token_ids
=
ConstantList
(
self
.
_all_token_ids
)
# State
# The number of tokens with prefix cache hits.
self
.
num_cached_tokens
=
-
1
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
if
request
.
mm_inputs
is
not
None
:
if
request
.
mm_inputs
is
not
None
:
...
...
Prev
1
…
28
29
30
31
32
33
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment