Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
372 additions
and
163 deletions
+372
-163
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+53
-9
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+32
-11
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+2
-2
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+2
-2
vllm/engine/protocol.py
vllm/engine/protocol.py
+10
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+69
-81
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+204
-58
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
vllm/engine/multiprocessing/client.py
View file @
afd0da21
...
...
@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
from
vllm.engine.protocol
import
EngineClient
...
...
@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
queue
=
self
.
output_queues
.
get
(
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
exception
)
# Put each output into the appropriate queue.
elif
isinstance
(
request_outputs
,
RPCAdapterLoadedResponse
):
self
.
_add_output
(
request_outputs
)
else
:
# Put each output into the appropriate steam.
for
request_output
in
request_outputs
:
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
self
.
_add_output
(
request_output
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
def
_add_output
(
self
,
request_output
:
Union
[
RequestOutput
,
RPCAdapterLoadedResponse
]):
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
# Start output_loop
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
if
self
.
output_loop
is
None
:
# only generate once to avoid multiple concurrent output_loops
# this will lead to race conditions and wrong orders of tokens
# returned by the engine
# setup will be called multiple times during the startup of
# the engine
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
with
self
.
get_data_socket
()
as
socket
:
# Wait until server is ready.
...
...
@@ -264,8 +279,9 @@ class MQLLMEngineClient(EngineClient):
self
.
tracing_flag
=
response
.
tracing_enabled
# Start health_loop.
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
if
self
.
health_loop
is
None
:
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
...
...
@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
socket
=
self
.
input_socket
)
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request
=
RPCLoadAdapterRequest
(
lora_request
)
# Create output queue for this requests.
queue
:
asyncio
.
Queue
[
Union
[
None
,
BaseException
]]
=
asyncio
.
Queue
()
self
.
output_queues
[
request
.
request_id
]
=
queue
# Send the request
request_bytes
=
pickle
.
dumps
(
request
)
await
self
.
input_socket
.
send_multipart
((
request_bytes
,
),
copy
=
False
)
# Wait for the response
request_output
=
await
queue
.
get
()
self
.
output_queues
.
pop
(
request
.
request_id
)
# Raise on error, otherwise happily return None
if
isinstance
(
request_output
,
BaseException
):
raise
request_output
vllm/engine/multiprocessing/engine.py
View file @
afd0da21
...
...
@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
...
...
@@ -234,6 +236,10 @@ class MQLLMEngine:
self
.
start_profile
()
else
:
self
.
stop_profile
()
elif
isinstance
(
request
,
RPCLoadAdapterRequest
):
self
.
_handle_load_adapter_request
(
request
)
elif
isinstance
(
request
,
RPCResetPrefixCacheRequest
):
self
.
reset_prefix_cache
()
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
...
...
@@ -284,6 +290,20 @@ class MQLLMEngine:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_handle_load_adapter_request
(
self
,
request
:
RPCLoadAdapterRequest
):
try
:
self
.
engine
.
add_lora
(
request
.
lora_request
)
except
BaseException
as
e
:
# Send back an error if the adater fails to load
rpc_err
=
RPCError
(
request_id
=
request
.
request_id
,
is_engine_errored
=
False
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
return
# Otherwise, send back the successful load message
self
.
_send_outputs
(
RPCAdapterLoadedResponse
(
request_id
=
request
.
request_id
))
def
_health_check
(
self
):
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
...
...
@@ -296,7 +316,11 @@ class MQLLMEngine:
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if
outputs
:
try
:
from
ray.exceptions
import
RayTaskError
...
...
@@ -335,16 +359,13 @@ class MQLLMEngine:
self
.
_errored_with
=
e
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
self
.
engine
.
start_profile
()
def
stop_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
stop_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop
_pr
o
fi
le"
)
self
.
engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
engine
.
reset
_pr
e
fi
x_cache
(
)
def
signal_handler
(
*
_
)
->
None
:
...
...
vllm/engine/output_processor/multi_step.py
View file @
afd0da21
...
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@
staticmethod
@
functools
.
lru_cache
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/
usage
/compatibility_matrix.md
# Reminder: Please update docs/source/
features
/compatibility_matrix.md
# If the feature combo become valid
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
...
...
@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
if
sampling_params
.
detokenize
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
...
...
vllm/engine/output_processor/single_step.py
View file @
afd0da21
...
...
@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
output
s
: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only ha
s
1 output."
)
assert
len
(
outputs
)
==
1
,
"Single step should only ha
ve
1 output."
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
...
...
vllm/engine/protocol.py
View file @
afd0da21
...
...
@@ -270,3 +270,13 @@ class EngineClient(ABC):
async
def
stop_profile
(
self
)
->
None
:
"""Start profiling the engine"""
...
@
abstractmethod
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
...
@
abstractmethod
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
...
vllm/entrypoints/chat_utils.py
View file @
afd0da21
...
...
@@ -3,10 +3,10 @@ import codecs
import
json
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
functools
import
lru_cache
,
partial
from
functools
import
cache
,
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
import
jinja2.nodes
import
transformers.utils.chat_template_utils
as
hf_chat_utils
...
...
@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
from
openai.types.chat
import
(
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
)
from
openai.types.chat.chat_completion_content_part_input_audio_param
import
(
InputAudio
)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
...
@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
,
async_get_and_parse_video
,
get_and_parse_audio
,
get_and_parse_image
,
get_and_parse_video
)
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_items
:
List
[
_T
]
=
[]
self
.
_items
_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
@
property
def
model_config
(
self
)
->
ModelConfig
:
return
self
.
_model_config
@
property
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
@
staticmethod
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
...
...
@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
if
model_type
in
(
"minicpmo"
,
"minicpmv"
)
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
"pixtral"
):
...
...
@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
if
model_type
in
(
"chameleon"
,
"deepseek_vl_v2"
,
"internvl_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
return
"<image>"
if
model_type
==
"mllama"
:
return
"<|image|>"
...
...
@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"qwen2_audio"
:
return
(
f
"Audio
{
current_count
}
: "
f
"<|audio_bos|><|AUDIO|><|audio_eos|>"
)
if
model_type
==
"minicpmo"
:
return
"(<audio>./</audio>)"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|video_pad|><|vision_end|>"
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
return
"(<video>./</video>)"
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
video_token_index
)
...
...
@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
@
staticmethod
def
_combine
(
items
:
List
[
MultiModalDataDict
])
->
MultiModalDataDict
:
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
# Merge all the multi-modal items
for
single_mm_data
in
items
:
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
if
isinstance
(
mm_item
,
list
):
mm_lists
[
mm_key
].
extend
(
mm_item
)
else
:
mm_lists
[
mm_key
].
append
(
mm_item
)
# Unpack any single item lists for models that don't expect multiple.
return
{
mm_key
:
mm_list
[
0
]
if
len
(
mm_list
)
==
1
else
mm_list
for
mm_key
,
mm_list
in
mm_lists
.
items
()
}
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_
consumed_items
.
get
(
modality
,
0
)
+
1
current_count
=
len
(
self
.
_
items_by_modality
[
modality
]
)
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_items
.
append
(
item
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
...
...
@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
MultiModalDataDi
ct
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
obje
ct
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
return
self
.
_combine
(
self
.
_items
)
if
self
.
_items
else
None
if
self
.
_items_by_modality
:
return
dict
(
self
.
_items_by_modality
)
return
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
MultiModalDataDict
]]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
self
.
_items
:
items
=
await
asyncio
.
gather
(
*
self
.
_items
)
return
self
.
_combine
(
items
)
if
self
.
_items_by_modality
:
return
{
modality
:
await
asyncio
.
gather
(
*
items
)
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
return
None
...
...
@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
raise
NotImplementedError
@
abstractmethod
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
]
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
raise
NotImplementedError
@
abstractmethod
...
...
@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image
=
get_and_parse_image
(
image_url
,
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
image
=
self
.
_connector
.
fetch_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
get_and_parse
_audio
(
audio_url
)
audio
=
self
.
_connector
.
fetch
_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio
=
get_and_parse_audio
(
audio_url
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
return
self
.
parse_audio
(
audio_url
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
get_and_parse
_video
(
video_url
)
video
=
self
.
_connector
.
fetch
_video
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
...
...
@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super
().
__init__
()
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image_coro
=
async_get_and_parse_image
(
image_url
,
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
return
self
.
parse_audio
(
audio_url
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
async_get_and_parse_video
(
video_url
)
video
=
self
.
_connector
.
fetch_video_async
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
...
...
@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_VideoParser
=
partial
(
cast
,
ChatCompletionContentPartVideoParam
)
_ContentPart
:
TypeAlias
=
Union
[
str
,
Dict
[
str
,
str
],
InputAudio
]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP
:
Dict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
Union
[
str
,
Dict
[
str
,
str
]]]]
=
{
MM_PARSER_MAP
:
Dict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
...
...
@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
Tuple
[
str
,
Union
[
str
,
Dict
[
str
,
str
]]]:
part
:
ChatCompletionContentPartParam
)
->
tuple
[
str
,
_ContentPart
]:
"""
Parses a given multi-modal content part based on its type.
...
...
@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
*
,
wrap_dicts
:
bool
,
)
->
List
[
ConversationMessage
]:
content
:
List
[
Union
[
str
,
Dict
[
str
,
str
]]]
=
[]
content
=
list
[
_ContentPart
]()
mm_parser
=
mm_tracker
.
create_parser
()
...
...
@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
mm_parser
:
BaseMultiModalContentParser
,
*
,
wrap_dicts
:
bool
,
)
->
Optional
[
Union
[
str
,
Dict
[
str
,
str
]]
]:
)
->
Optional
[
_ContentPart
]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
...
...
@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if
isinstance
(
part
,
str
):
# Handle plain text parts
text
=
_TextParser
(
part
)
return
text
return
part
# Handle structured dictionary parts
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
...
...
@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
if
part_type
==
"input_audio"
:
dict_content
=
cast
(
Dict
[
str
,
str
]
,
content
)
dict_content
=
cast
(
InputAudio
,
content
)
mm_parser
.
parse_input_audio
(
dict_content
)
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
...
...
@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
**
kwargs
:
Any
,
)
->
List
[
int
]:
if
chat_template
is
not
None
:
print_
warning_once
(
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
...
...
vllm/entrypoints/llm.py
View file @
afd0da21
import
itertools
import
warnings
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
from
typing
import
(
Any
,
Callable
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
import
cloudpickle
import
torch
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
typing_extensions
import
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm
import
envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
...
@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages
,
resolve_chat_template_content_format
)
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
is_token_prompt
,
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
...
...
@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -186,6 +191,13 @@ class LLM:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
if
"worker_cls"
in
kwargs
:
worker_cls
=
kwargs
[
"worker_cls"
]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if
isinstance
(
worker_cls
,
type
):
kwargs
[
"worker_cls"
]
=
cloudpickle
.
dumps
(
worker_cls
)
if
compilation_config
is
not
None
:
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
compilation_config_instance
=
CompilationConfig
.
from_cli
(
...
...
@@ -225,18 +237,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self
.
engine_class
=
self
.
get_engine_class
()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self
.
llm_engine
=
self
.
engine_class
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
def
__del__
(
self
):
if
hasattr
(
self
,
'llm_engine'
)
and
self
.
llm_engine
and
hasattr
(
self
.
llm_engine
,
"shutdown"
):
self
.
llm_engine
.
shutdown
()
@
staticmethod
def
get_engine_class
()
->
Type
[
LLMEngine
]:
if
envs
.
VLLM_USE_V1
:
...
...
@@ -462,9 +467,47 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
List
[
_R
]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]
]],
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]],
params
:
BeamSearchParams
,
)
->
List
[
BeamSearchOutput
]:
"""
...
...
@@ -500,8 +543,10 @@ class LLM:
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
prompt_tokens
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
if
is_token_prompt
(
prompt
):
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
...
...
@@ -952,6 +997,107 @@ class LLM:
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_embedding_score
(
self
,
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
encoded_output
=
self
.
encode
(
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
encoded_output_1
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
=
encoded_output
[
len
(
text_1
):]
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
output_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
encoded_output_1
,
encoded_output_2
)]
scores
=
[]
scorer
=
torch
.
nn
.
CosineSimilarity
(
0
)
for
embed_1
,
embed_2
in
output_pairs
:
pair_score
=
scorer
(
embed_1
.
outputs
.
data
,
embed_2
.
outputs
.
data
)
if
(
pad_token_id
:
=
getattr
(
tokenizer
,
"pad_token_id"
,
None
))
is
not
None
:
tokens
=
embed_1
.
prompt_token_ids
+
[
pad_token_id
]
+
embed_2
.
prompt_token_ids
else
:
tokens
=
embed_1
.
prompt_token_ids
+
embed_2
.
prompt_token_ids
scores
.
append
(
PoolingRequestOutput
(
request_id
=
f
"
{
embed_1
.
request_id
}
_
{
embed_2
.
request_id
}
"
,
outputs
=
pair_score
,
prompt_token_ids
=
tokens
,
finished
=
True
))
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_cross_encoding_score
(
self
,
tokenizer
:
Union
[
AnyTokenizer
],
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Score API is only enabled for `--task embed or score`"
)
if
len
(
text_1
)
==
1
:
text_1
=
text_1
*
len
(
text_2
)
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
score
(
self
,
text_1
:
Union
[
SingletonPrompt
,
Sequence
[
SingletonPrompt
]],
...
...
@@ -1003,25 +1149,20 @@ class LLM:
raise
ValueError
(
" "
.
join
(
messages
))
if
not
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
raise
ValueError
(
"Your model does not support cross encoding"
)
if
self
.
llm_engine
.
model_config
.
task
!=
"score"
:
raise
ValueError
(
"Score API is only enabled for `--task score`"
)
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
self
.
llm_engine
.
model_config
.
task
not
in
(
"embed"
,
"score"
):
raise
ValueError
(
"
MistralTokenizer not support
ed
f
or
cross-encoding
"
)
"
Score API is only enabled for `--task emb
ed or
--task score`
"
)
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
def
ensure_str
(
prompt
:
SingletonPrompt
):
if
isinstance
(
prompt
,
dict
):
if
"multi_modal_data"
in
prompt
:
raise
ValueError
(
"Multi-modal prompt is not "
"supported for
cross en
co
d
ing"
)
"supported for
s
co
r
ing"
)
elif
"prompt_token_ids"
in
prompt
:
prompt
=
tokenizer
.
decode
(
cast
(
TokensPrompt
,
prompt
)[
"prompt_token_ids"
])
...
...
@@ -1047,40 +1188,15 @@ class LLM:
if
len
(
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
if
len
(
text_1
)
==
1
:
text_1
=
text_1
*
len
(
text_2
)
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
else
:
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
...
...
@@ -1088,6 +1204,36 @@ class LLM:
def
stop_profile
(
self
)
->
None
:
self
.
llm_engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
llm_engine
.
reset_prefix_cache
()
def
sleep
(
self
,
level
:
int
=
1
):
"""
Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache is
forgotten. Level 1 sleep is good for sleeping and waking up the
engine to run the same model again. The model weights are backed
up in CPU memory. Please make sure there's enough CPU memory to
store the model weights. Level 2 sleep will discard both the model
weights and the kv cache. The content of both the model weights
and kv cache is forgotten. Level 2 sleep is good for sleeping and
waking up the engine to run a different model or update the model,
where previous model weights are not needed. It reduces CPU memory
pressure.
"""
self
.
reset_prefix_cache
()
self
.
llm_engine
.
sleep
(
level
=
level
)
def
wake_up
(
self
):
"""
Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details."""
self
.
llm_engine
.
wake_up
()
# LEGACY
def
_convert_v1_inputs
(
self
,
...
...
Prev
1
…
26
27
28
29
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment