Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31f6b24f
Commit
31f6b24f
authored
Mar 26, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori
parents
89d1dd57
25f560a6
Changes
88
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
335 additions
and
154 deletions
+335
-154
vllm/compilation/torch25_custom_graph_pass.py
vllm/compilation/torch25_custom_graph_pass.py
+41
-0
vllm/config.py
vllm/config.py
+12
-17
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+11
-4
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+13
-26
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+23
-15
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+1
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+112
-45
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+5
-2
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+6
-1
vllm/envs.py
vllm/envs.py
+9
-4
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+3
-5
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+5
-9
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+50
-0
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+4
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-3
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+2
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+30
-12
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
No files found.
vllm/compilation/torch25_custom_graph_pass.py
0 → 100644
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
import
torch
class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
    """
    Stand-in for torch==2.6's CustomGraphPass when running on torch<2.6.

    The 2.6 interface is kept as-is. In addition, pickling is supported,
    because before 2.6 the inductor code cache pickles the pass to work out
    its cache key (from 2.6 onwards uuid() is used instead).
    Subclasses can simply "pretend" that uuid is what gets used.
    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Apply the custom pass to the given graph.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Produce an ID that uniquely identifies this custom pass
        implementation. Returning None skips inductor code caching entirely.
        """

    def __getstate__(self):
        """
        On torch<2.6 the pass is pickled rather than asked for uuid().
        Handing back uuid() as the pickled state means subclasses only
        ever need to implement uuid.
        """
        cache_key = self.uuid()
        return cache_key

    def __setstate__(self, state):
        # The pickled state is only a cache key, not a restorable object,
        # so unpickling is always refused.
        message = ("Cannot unpickle CustomGraphPass because pickling"
                   " is used for cache key uuid. Use torch>=2.6 with"
                   " native uuid support for custom passes.")
        raise ValueError(message)
vllm/config.py
View file @
31f6b24f
...
@@ -4,6 +4,7 @@ import ast
...
@@ -4,6 +4,7 @@ import ast
import
copy
import
copy
import
enum
import
enum
import
hashlib
import
hashlib
import
importlib.metadata
import
json
import
json
import
sys
import
sys
import
warnings
import
warnings
...
@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
...
@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional
,
Protocol
,
Union
)
Optional
,
Protocol
,
Union
)
import
torch
import
torch
from
packaging.version
import
Version
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -52,8 +54,6 @@ if TYPE_CHECKING:
...
@@ -52,8 +54,6 @@ if TYPE_CHECKING:
else
:
else
:
QuantizationConfig
=
None
QuantizationConfig
=
None
from
packaging.version
import
Version
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# This value is chosen to have a balance between ITL and TTFT. Note it is
...
@@ -1157,10 +1157,6 @@ class CacheConfig:
...
@@ -1157,10 +1157,6 @@ class CacheConfig:
if
self
.
cache_dtype
==
"auto"
:
if
self
.
cache_dtype
==
"auto"
:
pass
pass
elif
self
.
cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
):
elif
self
.
cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
):
if
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"V1 does not yet support fp8 KV cache. "
"Set VLLM_USE_V1=0 to enable fp8 kv cache."
)
logger
.
info
(
logger
.
info
(
"Using fp8 data type to store kv cache. It reduces the GPU "
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"memory footprint and boosts the performance. "
...
@@ -1281,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
...
@@ -1281,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES
=
"bitsandbytes"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
MISTRAL
=
"mistral"
RUNAI_STREAMER
=
"runai_streamer"
RUNAI_STREAMER
=
"runai_streamer"
FASTSAFETENSORS
=
"fastsafetensors"
@
dataclass
@
dataclass
...
@@ -2376,12 +2373,6 @@ class LoRAConfig:
...
@@ -2376,12 +2373,6 @@ class LoRAConfig:
self
.
lora_dtype
=
model_config
.
dtype
self
.
lora_dtype
=
model_config
.
dtype
elif
isinstance
(
self
.
lora_dtype
,
str
):
elif
isinstance
(
self
.
lora_dtype
,
str
):
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
"awq"
,
"gptq"
]:
# TODO support marlin
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
model_config
.
quantization
)
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
...
@@ -2809,12 +2800,17 @@ class DecodingConfig:
...
@@ -2809,12 +2800,17 @@ class DecodingConfig:
return
hash_str
return
hash_str
def
__post_init__
(
self
):
def
__post_init__
(
self
):
valid_guided_backends
=
[
v0_
valid_guided_backends
=
[
'outlines'
,
'lm-format-enforcer'
,
'xgrammar'
,
'guidance'
'outlines'
,
'lm-format-enforcer'
,
'xgrammar'
]
]
v1_valid_guided_backends
=
[
'xgrammar'
,
'guidance'
,
'auto'
]
backend
=
GuidedDecodingParams
(
backend
=
GuidedDecodingParams
(
backend
=
self
.
guided_decoding_backend
).
backend_name
backend
=
self
.
guided_decoding_backend
).
backend_name
if
envs
.
VLLM_USE_V1
:
valid_guided_backends
=
v1_valid_guided_backends
else
:
valid_guided_backends
=
v0_valid_guided_backends
if
backend
not
in
valid_guided_backends
:
if
backend
not
in
valid_guided_backends
:
raise
ValueError
(
f
"Invalid guided_decoding_backend '
{
backend
}
',"
raise
ValueError
(
f
"Invalid guided_decoding_backend '
{
backend
}
',"
f
" must be one of
{
valid_guided_backends
}
"
)
f
" must be one of
{
valid_guided_backends
}
"
)
...
@@ -3092,8 +3088,7 @@ class CompilationConfig(BaseModel):
...
@@ -3092,8 +3088,7 @@ class CompilationConfig(BaseModel):
compilation.
compilation.
"""
"""
dict_
=
self
.
model_dump
(
include
=
{
"enable_fusion"
,
"enable_noop"
})
dict_
=
self
.
model_dump
(
include
=
{
"enable_fusion"
,
"enable_noop"
})
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
InductorPass
.
hash_dict
(
dict_
)
return
hashlib
.
sha256
(
encoded
).
digest
()
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
if
not
self
.
enable_noop
and
self
.
enable_fusion
:
if
not
self
.
enable_noop
and
self
.
enable_fusion
:
...
@@ -3182,7 +3177,7 @@ class CompilationConfig(BaseModel):
...
@@ -3182,7 +3177,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here:
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
# https://github.com/vllm-project/vllm/issues/14703
if
Version
(
torch
.
__version__
)
>=
Version
(
"2.6"
):
if
Version
(
importlib
.
metadata
.
version
(
'torch'
)
)
>=
Version
(
"2.6"
):
KEY
=
'enable_auto_functionalized_v2'
KEY
=
'enable_auto_functionalized_v2'
if
KEY
not
in
self
.
inductor_compile_config
:
if
KEY
not
in
self
.
inductor_compile_config
:
self
.
inductor_compile_config
[
KEY
]
=
False
self
.
inductor_compile_config
[
KEY
]
=
False
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
31f6b24f
...
@@ -233,6 +233,7 @@ class MessageQueue:
...
@@ -233,6 +233,7 @@ class MessageQueue:
if
is_valid_ipv6_address
(
connect_ip
):
if
is_valid_ipv6_address
(
connect_ip
):
self
.
remote_socket
.
setsockopt
(
IPV6
,
1
)
self
.
remote_socket
.
setsockopt
(
IPV6
,
1
)
remote_addr_ipv6
=
True
remote_addr_ipv6
=
True
connect_ip
=
f
"[
{
connect_ip
}
]"
socket_addr
=
f
"tcp://*:
{
remote_subscribe_port
}
"
socket_addr
=
f
"tcp://*:
{
remote_subscribe_port
}
"
self
.
remote_socket
.
bind
(
socket_addr
)
self
.
remote_socket
.
bind
(
socket_addr
)
remote_subscribe_addr
=
f
"tcp://
{
connect_ip
}
:
{
remote_subscribe_port
}
"
remote_subscribe_addr
=
f
"tcp://
{
connect_ip
}
:
{
remote_subscribe_port
}
"
...
@@ -356,8 +357,11 @@ class MessageQueue:
...
@@ -356,8 +357,11 @@ class MessageQueue:
# if we wait for a long time, log a message
# if we wait for a long time, log a message
if
(
time
.
monotonic
()
-
start_time
if
(
time
.
monotonic
()
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
logger
.
debug
(
"No available block found in %s second. "
,
logger
.
debug
(
VLLM_RINGBUFFER_WARNING_INTERVAL
)
(
"No available shared memory broadcast block found"
" in %s second."
),
VLLM_RINGBUFFER_WARNING_INTERVAL
,
)
n_warning
+=
1
n_warning
+=
1
# if we time out, raise an exception
# if we time out, raise an exception
...
@@ -414,8 +418,11 @@ class MessageQueue:
...
@@ -414,8 +418,11 @@ class MessageQueue:
# if we wait for a long time, log a message
# if we wait for a long time, log a message
if
(
time
.
monotonic
()
-
start_time
if
(
time
.
monotonic
()
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
logger
.
debug
(
"No available block found in %s second. "
,
logger
.
debug
(
VLLM_RINGBUFFER_WARNING_INTERVAL
)
(
"No available shared memory broadcast block found"
" in %s second."
),
VLLM_RINGBUFFER_WARNING_INTERVAL
,
)
n_warning
+=
1
n_warning
+=
1
# if we time out, raise an exception
# if we time out, raise an exception
...
...
vllm/distributed/parallel_state.py
View file @
31f6b24f
...
@@ -897,29 +897,22 @@ def initialize_model_parallel(
...
@@ -897,29 +897,22 @@ def initialize_model_parallel(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
data_parallel_size
=
1
data_parallel_size
=
1
has_external_dp
=
False
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
config
=
get_current_vllm_config
()
config
=
get_current_vllm_config
()
if
config
is
not
None
:
if
config
is
not
None
:
if
config
.
parallel_config
.
world_size
!=
world_size
:
# detect external data parallelism.
# dp in vllm means all dp instances need to run together.
# if the world size does not match, it means this dp is external,
# and the dp instances can run independently, e.g. in rlhf workflow
# from https://github.com/volcengine/verl .
# in that case, we treat the rest dimensions as if they are
# data parallel, and create a dummy dp group that is not used.
data_parallel_size
=
world_size
//
(
pipeline_model_parallel_size
*
tensor_model_parallel_size
)
has_external_dp
=
True
else
:
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
# the layout order is: DP x PP x TP
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
# DP is the data parallel group that is part of the model,
# all the ranks in the same DP group should generate simultaneously,
# i.e. the `generate` call in the same DP group should be called together,
# otherwise it will cause deadlock.
# to get group_ranks for each dimension, transpose that dimension to the
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks
=
torch
.
arange
(
world_size
).
reshape
(
all_ranks
=
torch
.
arange
(
world_size
).
reshape
(
data_parallel_size
,
pipeline_model_parallel_size
,
-
1
,
data_parallel_size
,
pipeline_model_parallel_size
,
tensor_model_parallel_size
)
# noqa
tensor_model_parallel_size
)
# noqa
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups.
...
@@ -939,7 +932,7 @@ def initialize_model_parallel(
...
@@ -939,7 +932,7 @@ def initialize_model_parallel(
global
_PP
global
_PP
assert
_PP
is
None
,
(
assert
_PP
is
None
,
(
"pipeline model parallel group is already initialized"
)
"pipeline model parallel group is already initialized"
)
group_ranks
=
all_ranks
.
transpose
(
1
,
2
).
reshape
(
group_ranks
=
all_ranks
.
transpose
(
2
,
3
).
reshape
(
-
1
,
pipeline_model_parallel_size
).
unbind
(
0
)
-
1
,
pipeline_model_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
...
@@ -949,16 +942,10 @@ def initialize_model_parallel(
...
@@ -949,16 +942,10 @@ def initialize_model_parallel(
global
_DP
global
_DP
assert
_DP
is
None
,
(
"data parallel group is already initialized"
)
assert
_DP
is
None
,
(
"data parallel group is already initialized"
)
group_ranks
=
all_ranks
.
transpose
(
0
,
group_ranks
=
all_ranks
.
transpose
(
1
,
2
).
reshape
(
-
1
,
3
).
reshape
(
-
1
,
data_parallel_size
).
unbind
(
0
)
data_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
if
has_external_dp
:
# create a dummy dp group that is not used actually,
# since this dp is external.
# a dummy dp group means every rank is a group itself.
# this way, no communication is needed, no memory is wasted.
group_ranks
=
[[
x
]
for
x
in
range
(
world_size
)]
_DP
=
init_model_parallel_group
(
group_ranks
,
_DP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
...
...
vllm/engine/arg_utils.py
View file @
31f6b24f
...
@@ -391,16 +391,13 @@ class EngineArgs:
...
@@ -391,16 +391,13 @@ class EngineArgs:
default
=
'xgrammar'
,
default
=
'xgrammar'
,
help
=
'Which engine will be used for guided decoding'
help
=
'Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines, '
'https://github.com/mlc-ai/xgrammar and '
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/guidance-ai/llguidance.'
'https://github.com/noamgat/lm-format-enforcer.'
'Valid backend values are "xgrammar", "guidance", and "auto". '
' Can be overridden per request via guided_decoding_backend'
'With "auto", we will make opinionated choices based on request'
' parameter.
\n
'
'contents and what the backend libraries currently support, so '
'Backend-specific options can be supplied in a comma-separated '
'the behavior is subject to change in each release. '
'list following a colon after the backend name. Valid backends and '
'The default is xgrammar.'
)
'all available options are: [xgrammar:no-fallback, '
'xgrammar:disable-any-whitespace, '
'outlines:no-fallback, lm-format-enforcer:no-fallback]'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--logits-processor-pattern'
,
'--logits-processor-pattern'
,
type
=
nullable_str
,
type
=
nullable_str
,
...
@@ -1539,9 +1536,9 @@ class EngineArgs:
...
@@ -1539,9 +1536,9 @@ class EngineArgs:
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
return
False
return
False
#
Only support Xgrammar for guided decoding so far
.
#
Xgrammar and Guidance are supported
.
SUPPORTED_GUIDED_DECODING
=
[
SUPPORTED_GUIDED_DECODING
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
]
]
if
self
.
guided_decoding_backend
not
in
SUPPORTED_GUIDED_DECODING
:
if
self
.
guided_decoding_backend
not
in
SUPPORTED_GUIDED_DECODING
:
_raise_or_fallback
(
feature_name
=
"--guided-decoding-backend"
,
_raise_or_fallback
(
feature_name
=
"--guided-decoding-backend"
,
...
@@ -1562,6 +1559,17 @@ class EngineArgs:
...
@@ -1562,6 +1559,17 @@ class EngineArgs:
# No Fp8 KV cache so far.
# No Fp8 KV cache so far.
if
self
.
kv_cache_dtype
!=
"auto"
:
if
self
.
kv_cache_dtype
!=
"auto"
:
fp8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
will_use_fa
=
(
current_platform
.
is_cuda
()
and
not
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
)
or
envs
.
VLLM_ATTENTION_BACKEND
==
"FLASH_ATTN_VLLM_V1"
supported
=
False
if
fp8_attention
and
will_use_fa
:
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
)
supported
=
flash_attn_supports_fp8
()
if
not
supported
:
_raise_or_fallback
(
feature_name
=
"--kv-cache-dtype"
,
_raise_or_fallback
(
feature_name
=
"--kv-cache-dtype"
,
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
return
False
return
False
...
...
vllm/engine/async_llm_engine.py
View file @
31f6b24f
...
@@ -545,7 +545,7 @@ async def build_guided_decoding_logits_processor_async(
...
@@ -545,7 +545,7 @@ async def build_guided_decoding_logits_processor_async(
sampling_params
=
copy
.
copy
(
sampling_params
)
sampling_params
=
copy
.
copy
(
sampling_params
)
guided_decoding
=
sampling_params
.
guided_decoding
guided_decoding
=
sampling_params
.
guided_decoding
logger
.
info
(
logger
.
debug
(
"Building guided decoding logits processor. "
"Building guided decoding logits processor. "
"guided_decoding: %s%s"
,
guided_decoding
,
"guided_decoding: %s%s"
,
guided_decoding
,
f
", reasoning_backend:
{
reasoning_backend
}
"
f
", reasoning_backend:
{
reasoning_backend
}
"
...
...
vllm/engine/llm_engine.py
View file @
31f6b24f
...
@@ -1249,7 +1249,7 @@ class LLMEngine:
...
@@ -1249,7 +1249,7 @@ class LLMEngine:
return
None
return
None
def
_advance_to_next_step
(
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
]
,
self
,
output
:
SamplerOutput
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
"""Given model output from a single run, append the tokens to the
"""Given model output from a single run, append the tokens to the
...
...
vllm/entrypoints/chat_utils.py
View file @
31f6b24f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
codecs
import
json
import
json
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
...
@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
...
@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio
)
InputAudio
)
# yapf: enable
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
# pydantic needs the TypedDict from typing_extensions
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
ProcessorMixin
)
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -306,24 +306,63 @@ def _detect_content_format(
...
@@ -306,24 +306,63 @@ def _detect_content_format(
return
"openai"
return
"openai"
def
_resolve_hf_chat_template
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
chat_template
:
Optional
[
str
],
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]],
*
,
trust_remote_code
:
bool
,
)
->
Optional
[
str
]:
# 1st priority: The given chat template
if
chat_template
is
not
None
:
return
chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if
tools
is
None
:
try
:
processor
=
cached_get_processor
(
tokenizer
.
name_or_path
,
processor_cls
=
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
ProcessorMixin
),
trust_remote_code
=
trust_remote_code
,
)
if
isinstance
(
processor
,
ProcessorMixin
)
and
\
processor
.
chat_template
is
not
None
:
return
processor
.
chat_template
except
Exception
:
logger
.
debug
(
"Failed to load AutoProcessor chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
)
# 3rd priority: AutoTokenizer chat template
try
:
return
tokenizer
.
get_chat_template
(
chat_template
,
tools
=
tools
)
except
Exception
:
logger
.
debug
(
"Failed to load AutoTokenizer chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
)
return
None
def
_resolve_chat_template_content_format
(
def
_resolve_chat_template_content_format
(
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]],
given_format
:
ChatTemplateContentFormatOption
,
given_format
:
ChatTemplateContentFormatOption
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
*
,
trust_remote_code
:
bool
,
)
->
_ChatTemplateContentFormat
:
)
->
_ChatTemplateContentFormat
:
if
isinstance
(
tokenizer
,
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)):
if
isinstance
(
tokenizer
,
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)):
tokenizer_chat_template
=
tokenizer
.
chat_template
hf_chat_template
=
_resolve_hf_chat_template
(
else
:
tokenizer
,
tokenizer_chat_template
=
None
chat_template
=
chat_template
,
trust_remote_code
=
trust_remote_code
,
jinja_text
:
Optional
[
str
]
tools
=
tools
,
if
isinstance
(
tokenizer_chat_template
,
str
)
and
chat_template
is
None
:
)
jinja_text
=
tokenizer_chat_template
elif
(
isinstance
(
tokenizer_chat_template
,
dict
)
and
chat_template
in
tokenizer_chat_template
):
jinja_text
=
tokenizer_chat_template
[
chat_template
]
else
:
else
:
jinja_text
=
load_chat_template
(
chat_template
,
is_literal
=
True
)
hf_chat_template
=
None
jinja_text
=
(
hf_chat_template
if
isinstance
(
hf_chat_template
,
str
)
else
load_chat_template
(
chat_template
,
is_literal
=
True
))
detected_format
=
(
"string"
if
jinja_text
is
None
else
detected_format
=
(
"string"
if
jinja_text
is
None
else
_detect_content_format
(
jinja_text
,
default
=
"string"
))
_detect_content_format
(
jinja_text
,
default
=
"string"
))
...
@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
...
@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
@
lru_cache
@
lru_cache
def
resolve
_chat_template_content_format
(
def
_log
_chat_template_content_format
(
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
given_format
:
ChatTemplateContentFormatOption
,
given_format
:
ChatTemplateContentFormatOption
,
tokenizer
:
AnyTokenizer
,
detected_format
:
ChatTemplateContentFormatOption
,
)
->
_ChatTemplateContentFormat
:
):
detected_format
=
_resolve_chat_template_content_format
(
chat_template
,
given_format
,
tokenizer
,
)
logger
.
info
(
logger
.
info
(
"Detected the chat template content format to be '%s'. "
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this."
,
"You can set `--chat-template-content-format` to override this."
,
...
@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
...
@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
detected_format
,
detected_format
,
)
)
def
resolve_chat_template_content_format
(
chat_template
:
Optional
[
str
],
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]],
given_format
:
ChatTemplateContentFormatOption
,
tokenizer
:
AnyTokenizer
,
*
,
trust_remote_code
:
bool
=
False
,
)
->
_ChatTemplateContentFormat
:
detected_format
=
_resolve_chat_template_content_format
(
chat_template
,
tools
,
given_format
,
tokenizer
,
trust_remote_code
=
trust_remote_code
,
)
_log_chat_template_content_format
(
chat_template
,
given_format
=
given_format
,
detected_format
=
detected_format
,
)
return
detected_format
return
detected_format
...
@@ -500,11 +556,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
...
@@ -500,11 +556,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
raise
ValueError
(
\
raise
ValueError
(
\
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
el
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
el
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
el
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
...
@@ -533,11 +589,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
...
@@ -533,11 +589,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
raise
ValueError
(
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
el
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
el
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
el
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
...
@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
...
@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
def
load_chat_template
(
def
_
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]],
chat_template
:
Optional
[
Union
[
Path
,
str
]],
*
,
*
,
is_literal
:
bool
=
False
,
is_literal
:
bool
=
False
,
...
@@ -724,7 +780,7 @@ def load_chat_template(
...
@@ -724,7 +780,7 @@ def load_chat_template(
raise
TypeError
(
"chat_template is expected to be read directly "
raise
TypeError
(
"chat_template is expected to be read directly "
"from its value"
)
"from its value"
)
return
c
odecs
.
decode
(
chat_template
,
"unicode_escape"
)
return
c
hat_template
try
:
try
:
with
open
(
chat_template
)
as
f
:
with
open
(
chat_template
)
as
f
:
...
@@ -742,7 +798,18 @@ def load_chat_template(
...
@@ -742,7 +798,18 @@ def load_chat_template(
# If opening a file fails, set chat template to be args to
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
# ensure we decode so our escape are interpreted correctly
return
load_chat_template
(
chat_template
,
is_literal
=
True
)
return
_load_chat_template
(
chat_template
,
is_literal
=
True
)
_cached_load_chat_template
=
lru_cache
(
_load_chat_template
)
def
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]],
*
,
is_literal
:
bool
=
False
,
)
->
Optional
[
str
]:
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
# TODO: Let user specify how to insert multimodal tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
...
@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
...
@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
conversation
:
list
[
ConversationMessage
],
conversation
:
list
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]],
*
,
*
,
trust_remote_code
:
bool
=
False
,
tokenize
:
bool
=
False
,
# Different from HF's default
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
str
:
)
->
str
:
if
chat_template
is
None
:
hf_chat_template
=
_resolve_hf_chat_template
(
chat_template
=
tokenizer
.
chat_template
tokenizer
,
chat_template
=
chat_template
,
# FIXME: Temporary workaround for
tools
=
tools
,
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31
trust_remote_code
=
trust_remote_code
,
if
chat_template
is
None
:
)
try
:
processor
=
cached_get_processor
(
tokenizer
.
name_or_path
)
chat_template
=
processor
.
chat_template
except
Exception
:
pass
if
chat_template
is
None
:
if
hf_
chat_template
is
None
:
raise
ValueError
(
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"allowed, so you must provide a chat template if the tokenizer "
...
@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
...
@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
return
tokenizer
.
apply_chat_template
(
return
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
# type: ignore[arg-type]
conversation
=
conversation
,
# type: ignore[arg-type]
chat_template
=
chat_template
,
tools
=
tools
,
# type: ignore[arg-type]
chat_template
=
hf_chat_template
,
tokenize
=
tokenize
,
tokenize
=
tokenize
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
...
@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
def
apply_mistral_chat_template
(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
tokenizer
:
MistralTokenizer
,
messages
:
list
[
ChatCompletionMessageParam
],
messages
:
list
[
ChatCompletionMessageParam
],
chat_template
:
Optional
[
str
]
=
None
,
chat_template
:
Optional
[
str
],
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]],
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
list
[
int
]:
)
->
list
[
int
]:
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
...
@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
...
@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
return
tokenizer
.
apply_chat_template
(
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
messages
=
messages
,
tools
=
tools
,
**
kwargs
,
**
kwargs
,
)
)
vllm/entrypoints/llm.py
View file @
31f6b24f
...
@@ -690,8 +690,10 @@ class LLM:
...
@@ -690,8 +690,10 @@ class LLM:
model_config
=
self
.
llm_engine
.
get_model_config
()
model_config
=
self
.
llm_engine
.
get_model_config
()
resolved_content_format
=
resolve_chat_template_content_format
(
resolved_content_format
=
resolve_chat_template_content_format
(
chat_template
,
chat_template
,
tools
,
chat_template_content_format
,
chat_template_content_format
,
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
...
@@ -713,18 +715,19 @@ class LLM:
...
@@ -713,18 +715,19 @@ class LLM:
tokenizer
,
tokenizer
,
messages
=
msgs
,
messages
=
msgs
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
tools
=
tools
,
add_generation_prompt
=
add_generation_prompt
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
)
)
else
:
else
:
prompt_data
=
apply_hf_chat_template
(
prompt_data
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
tools
=
tools
,
add_generation_prompt
=
add_generation_prompt
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
)
)
prompt
:
Union
[
TokensPrompt
,
TextPrompt
]
prompt
:
Union
[
TokensPrompt
,
TextPrompt
]
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
31f6b24f
...
@@ -379,14 +379,18 @@ class OpenAIServing:
...
@@ -379,14 +379,18 @@ class OpenAIServing:
add_special_tokens
:
bool
=
False
,
add_special_tokens
:
bool
=
False
,
)
->
tuple
[
list
[
ConversationMessage
],
Sequence
[
RequestPrompt
],
)
->
tuple
[
list
[
ConversationMessage
],
Sequence
[
RequestPrompt
],
list
[
TokensPrompt
]]:
list
[
TokensPrompt
]]:
model_config
=
self
.
model_config
resolved_content_format
=
resolve_chat_template_content_format
(
resolved_content_format
=
resolve_chat_template_content_format
(
chat_template
,
chat_template
,
tool_dicts
,
chat_template_content_format
,
chat_template_content_format
,
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
conversation
,
mm_data_future
=
parse_chat_messages_futures
(
conversation
,
mm_data_future
=
parse_chat_messages_futures
(
messages
,
messages
,
self
.
model_config
,
model_config
,
tokenizer
,
tokenizer
,
content_format
=
resolved_content_format
,
content_format
=
resolved_content_format
,
)
)
...
@@ -410,6 +414,7 @@ class OpenAIServing:
...
@@ -410,6 +414,7 @@ class OpenAIServing:
else
:
else
:
request_prompt
=
apply_hf_chat_template
(
request_prompt
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
conversation
=
conversation
,
**
_chat_template_kwargs
,
**
_chat_template_kwargs
,
)
)
...
...
vllm/envs.py
View file @
31f6b24f
...
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
...
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_MOE_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
...
@@ -294,7 +295,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -294,7 +295,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# this is used for configuring the default logging level
# this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL"
:
"VLLM_LOGGING_LEVEL"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_LEVEL"
,
"INFO"
),
lambda
:
os
.
getenv
(
"VLLM_LOGGING_LEVEL"
,
"INFO"
)
.
upper
()
,
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX"
:
"VLLM_LOGGING_PREFIX"
:
...
@@ -340,7 +341,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -340,7 +341,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
# (CPU backend only) CPU key-value cache space.
# (CPU backend only) CPU key-value cache space.
# default is 4
G
B
# default is 4
Gi
B
"VLLM_CPU_KVCACHE_SPACE"
:
"VLLM_CPU_KVCACHE_SPACE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_CPU_KVCACHE_SPACE"
,
"0"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_CPU_KVCACHE_SPACE"
,
"0"
)),
...
@@ -412,9 +413,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -412,9 +413,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
"VLLM_AUDIO_FETCH_TIMEOUT"
,
"10"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_AUDIO_FETCH_TIMEOUT"
,
"10"
)),
# Cache size (in GiB) for multimodal input cache
# Cache size (in GiB) for multimodal input cache
# Default is
8
GiB
# Default is
4
GiB
"VLLM_MM_INPUT_CACHE_GIB"
:
"VLLM_MM_INPUT_CACHE_GIB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MM_INPUT_CACHE_GIB"
,
"
8
"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_MM_INPUT_CACHE_GIB"
,
"
4
"
)),
# Path to the XLA persistent cache directory.
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
# Only used for XLA devices such as TPUs.
...
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_FP8_PADDING"
:
"VLLM_ROCM_FP8_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
# Pad the weights for the moe kernel
"VLLM_ROCM_MOE_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_MOE_PADDING"
,
"1"
))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT"
:
"Q_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"Q_SCALE_CONSTANT"
,
"200"
)),
lambda
:
int
(
os
.
getenv
(
"Q_SCALE_CONSTANT"
,
"200"
)),
...
...
vllm/executor/ray_utils.py
View file @
31f6b24f
...
@@ -289,16 +289,14 @@ def initialize_ray_cluster(
...
@@ -289,16 +289,14 @@ def initialize_ray_cluster(
elif
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
elif
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# Try to connect existing ray instance and create a new one if not found
# Try to connect existing ray instance and create a new one if not found
try
:
try
:
ray
.
init
(
"auto"
,
ignore_reinit_error
=
True
)
ray
.
init
(
"auto"
)
except
ConnectionError
:
except
ConnectionError
:
logger
.
warning
(
logger
.
warning
(
"No existing RAY instance detected. "
"No existing RAY instance detected. "
"A new instance will be launched with current node resources."
)
"A new instance will be launched with current node resources."
)
ray
.
init
(
address
=
ray_address
,
ray
.
init
(
address
=
ray_address
,
num_gpus
=
parallel_config
.
world_size
)
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
ray
.
init
(
address
=
ray_address
)
device_str
=
current_platform
.
ray_device_key
device_str
=
current_platform
.
ray_device_key
if
not
device_str
:
if
not
device_str
:
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
31f6b24f
...
@@ -78,10 +78,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -78,10 +78,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...],
scale
:
float
,
**
kwargs
):
...],
scale
:
float
,
**
kwargs
):
"""
"""
Performs GEMM for multiple slices of lora_a.
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
Semantics:
for i in range(len(lora_a_stacked)):
for i in range(len(lora_a_stacked)):
...
@@ -226,7 +222,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -226,7 +222,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
if
buffer
is
None
:
if
buffer
is
None
:
r
=
lora_b_stacked
[
0
].
size
(
-
1
)
r
=
lora_b_stacked
[
0
].
size
(
-
1
)
# We set the buffer to be float32 by default
,refer to:
# We set the buffer to be float32 by default,
refer to:
# https://github.com/triton-lang/triton/issues/1387
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
(
# type: ignore
buffer
=
torch
.
zeros
(
# type: ignore
(
len
(
output_slices
),
x
.
size
(
0
),
r
),
(
len
(
output_slices
),
x
.
size
(
0
),
r
),
...
@@ -268,16 +264,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -268,16 +264,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y (torch.Tensor): Output tensor.
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
lora_b_stacked (torch.Tensor):
lora_b's weights.
scale (float): Scaling factor.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
buffer (Optional[torch.Tensor]):
Default to None.
"""
"""
y_org
=
y
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
lora_b_stacked
.
size
(
-
1
)
r
=
lora_b_stacked
.
size
(
-
1
)
if
buffer
is
None
:
if
buffer
is
None
:
# We set the buffer to be float32 by default
,refer to:
# We set the buffer to be float32 by default,
refer to:
# https://github.com/triton-lang/triton/issues/1387
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
31f6b24f
...
@@ -815,7 +815,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -815,7 +815,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
B
.
shape
[
1
]
if
not
use_nn_moe
else
B
.
shape
[
2
],
B
.
shape
[
1
]
if
not
use_nn_moe
else
B
.
shape
[
2
],
A
.
shape
[
1
],
A
.
shape
[
2
],
EM
,
EM
,
topk_ids
.
numel
(),
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
0
),
...
@@ -1355,8 +1355,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1355,8 +1355,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous
"
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1
"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous
"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1
"
assert
hidden_states
.
dtype
in
[
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
31f6b24f
...
@@ -6,6 +6,7 @@ from enum import Enum
...
@@ -6,6 +6,7 @@ from enum import Enum
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Callable
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
UninitializedParameter
from
torch.nn.parameter
import
UninitializedParameter
from
vllm
import
envs
from
vllm
import
envs
...
@@ -111,9 +112,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -111,9 +112,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
_maybe_pad_weight
(
self
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if
(
envs
.
VLLM_ROCM_MOE_PADDING
and
current_platform
.
is_rocm
()
and
weight
.
stride
(
-
1
)
==
1
and
(
weight
.
stride
(
-
2
)
*
weight
.
element_size
())
%
512
==
0
):
num_pad
=
256
//
weight
.
element_size
()
weight
=
F
.
pad
(
weight
,
(
0
,
num_pad
),
"constant"
,
0
)[...,
:
-
num_pad
]
torch
.
cuda
.
empty_cache
()
return
weight
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
super
().
process_weights_after_loading
(
layer
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
),
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
),
requires_grad
=
False
)
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
():
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
import
intel_extension_for_pytorch
as
ipex
import
intel_extension_for_pytorch
as
ipex
...
@@ -233,6 +252,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -233,6 +252,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
,
e_score_correction_bias
,
)
)
def
forward_hpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
layer
is
not
None
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for HPU."
)
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
"Expert score correction bias is not supported for HPU."
)
return
layer
.
hpu_fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
)
def
forward_tpu
(
def
forward_tpu
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -432,6 +479,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -432,6 +479,9 @@ class FusedMoE(torch.nn.Module):
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
"non-grouped topk."
)
"non-grouped topk."
)
if
current_platform
.
is_hpu
():
from
vllm_hpu_extension.ops
import
DynamicFusedMOE
self
.
hpu_fused_moe
=
DynamicFusedMOE
(
self
.
num_experts
)
# Note: get_quant_method will look at the layer's local_num_experts
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
# for heuristic purposes, so it must be initialized first.
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
31f6b24f
...
@@ -155,12 +155,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
...
@@ -155,12 +155,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
try
:
try
:
import
bitsandbytes
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.45.
0
"
:
if
bitsandbytes
.
__version__
<
"0.45.
3
"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.
0
."
)
"install bitsandbytes>=0.45.
3
."
)
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.45.
0
via "
raise
ImportError
(
"Please install bitsandbytes>=0.45.
3
via "
"`pip install bitsandbytes>=0.45.
0
` to use "
"`pip install bitsandbytes>=0.45.
3
` to use "
"bitsandbytes quantizer."
)
from
err
"bitsandbytes quantizer."
)
from
err
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
31f6b24f
...
@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
layer
.
register_parameter
(
"input_scale"
,
None
)
def
add_padding_to
_weight
(
self
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_maybe_pad
_weight
(
self
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
# can benefit from tensors located far enough from one another in memory
if
(
envs
.
VLLM_ROCM_FP8_PADDING
and
current_platform
.
is_rocm
()
if
(
envs
.
VLLM_ROCM_FP8_PADDING
and
current_platform
.
is_rocm
()
...
@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
.
data
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight
=
self
.
add_padding_to
_weight
(
weight
)
weight
=
self
.
_maybe_pad
_weight
(
weight
)
# Torch.compile cannot use Parameter subclasses.
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
...
@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
weight
=
self
.
add_padding_to
_weight
(
weight
)
weight
=
self
.
_maybe_pad
_weight
(
weight
)
# Update layer with new values.
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/sampler.py
View file @
31f6b24f
...
@@ -1187,7 +1187,8 @@ def _build_sampler_output(
...
@@ -1187,7 +1187,8 @@ def _build_sampler_output(
deferred_sample_results_args
=
deferred_sample_results_args
)
deferred_sample_results_args
=
deferred_sample_results_args
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
tuple
[
int
,
...]:
"""Get a list of next prompt tokens to compute logprob from a
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
given sequence group.
...
...
vllm/model_executor/model_loader/loader.py
View file @
31f6b24f
...
@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
...
@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
set_default_torch_dtype
)
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
fastsafetensors_weights_iterator
,
filter_duplicate_safetensors_files
,
get_gguf_extra_tensor_names
,
get_lock
,
gguf_quant_weights_iterator
,
filter_files_not_needed_for_inference
,
get_gguf_extra_tensor_names
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
get_lock
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
runai_safetensors_weights_iterator
,
safetensors_weights_iterator
)
runai_safetensors_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
# Some quantized models use .pt files for storing the weights.
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
elif
(
load_format
==
LoadFormat
.
SAFETENSORS
or
load_format
==
LoadFormat
.
FASTSAFETENSORS
):
use_safetensors
=
True
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
elif
load_format
==
LoadFormat
.
MISTRAL
:
...
@@ -357,6 +359,12 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -357,6 +359,12 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
use_tqdm_on_load
,
)
)
elif
use_safetensors
:
elif
use_safetensors
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
FASTSAFETENSORS
:
weights_iterator
=
fastsafetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
else
:
weights_iterator
=
safetensors_weights_iterator
(
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
use_tqdm_on_load
,
...
@@ -379,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -379,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
elif
current_platform
.
is_hpu
():
import
habana_frameworks.torch.core
as
htcore
def
_hpu_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
htcore
.
mark_step
()
weights_iterator
=
_hpu_weights_iterator
(
weights_iterator
)
if
self
.
counter_before_loading_weights
==
0.0
:
if
self
.
counter_before_loading_weights
==
0.0
:
self
.
counter_before_loading_weights
=
time
.
perf_counter
()
self
.
counter_before_loading_weights
=
time
.
perf_counter
()
# Apply the prefix.
# Apply the prefix.
...
@@ -862,12 +880,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -862,12 +880,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try
:
try
:
import
bitsandbytes
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.45.
0
"
:
if
bitsandbytes
.
__version__
<
"0.45.
3
"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.
0
."
)
"install bitsandbytes>=0.45.
3
."
)
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.45.
0
via "
raise
ImportError
(
"Please install bitsandbytes>=0.45.
3
via "
"`pip install bitsandbytes>=0.45.
0
` to use "
"`pip install bitsandbytes>=0.45.
3
` to use "
"bitsandbytes quantizer."
)
from
err
"bitsandbytes quantizer."
)
from
err
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
...
...
vllm/model_executor/model_loader/utils.py
View file @
31f6b24f
...
@@ -32,7 +32,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -32,7 +32,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
is_transformers_impl_compatible
(
def
is_transformers_impl_compatible
(
arch
:
str
,
arch
:
str
,
module
:
Optional
[
transformers
.
PreTrainedModel
]
=
None
)
->
bool
:
module
:
Optional
[
"
transformers.PreTrainedModel
"
]
=
None
)
->
bool
:
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
if
mod
is
None
:
if
mod
is
None
:
return
False
return
False
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment