Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
372 additions
and
163 deletions
+372
-163
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+53
-9
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+32
-11
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+2
-2
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+2
-2
vllm/engine/protocol.py
vllm/engine/protocol.py
+10
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+69
-81
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+204
-58
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
vllm/engine/multiprocessing/client.py
View file @
afd0da21
...
@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
...
@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
...
@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
queue
=
self
.
output_queues
.
get
(
request_id
)
queue
=
self
.
output_queues
.
get
(
request_id
)
if
queue
is
not
None
:
if
queue
is
not
None
:
queue
.
put_nowait
(
exception
)
queue
.
put_nowait
(
exception
)
# Put each output into the appropriate queue.
elif
isinstance
(
request_outputs
,
RPCAdapterLoadedResponse
):
self
.
_add_output
(
request_outputs
)
else
:
else
:
# Put each output into the appropriate steam.
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
queue
=
self
.
output_queues
.
get
(
self
.
_add_output
(
request_output
)
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
def
_add_output
(
self
,
request_output
:
Union
[
RequestOutput
,
RPCAdapterLoadedResponse
]):
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
async
def
setup
(
self
):
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
"""Setup the client before it starts sending server requests."""
# Start output_loop
# Start output_loop
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
if
self
.
output_loop
is
None
:
# only generate once to avoid multiple concurrent output_loops
# this will lead to race conditions and wrong orders of tokens
# returned by the engine
# setup will be called multiple times during the startup of
# the engine
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
with
self
.
get_data_socket
()
as
socket
:
with
self
.
get_data_socket
()
as
socket
:
# Wait until server is ready.
# Wait until server is ready.
...
@@ -264,8 +279,9 @@ class MQLLMEngineClient(EngineClient):
...
@@ -264,8 +279,9 @@ class MQLLMEngineClient(EngineClient):
self
.
tracing_flag
=
response
.
tracing_enabled
self
.
tracing_flag
=
response
.
tracing_enabled
# Start health_loop.
# Start health_loop.
self
.
health_loop
=
asyncio
.
create_task
(
if
self
.
health_loop
is
None
:
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
def
close
(
self
):
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
"""Destroy the ZeroMQ Context."""
...
@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
...
@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
await
self
.
_send_one_way_rpc_request
(
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
socket
=
self
.
input_socket
)
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request
=
RPCLoadAdapterRequest
(
lora_request
)
# Create output queue for this request.
queue
:
asyncio
.
Queue
[
Union
[
None
,
BaseException
]]
=
asyncio
.
Queue
()
self
.
output_queues
[
request
.
request_id
]
=
queue
# Send the request
request_bytes
=
pickle
.
dumps
(
request
)
await
self
.
input_socket
.
send_multipart
((
request_bytes
,
),
copy
=
False
)
# Wait for the response
request_output
=
await
queue
.
get
()
self
.
output_queues
.
pop
(
request
.
request_id
)
# Raise on error, otherwise happily return None
if
isinstance
(
request_output
,
BaseException
):
raise
request_output
vllm/engine/multiprocessing/engine.py
View file @
afd0da21
...
@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
# yapf: enable
# yapf: enable
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -234,6 +236,10 @@ class MQLLMEngine:
...
@@ -234,6 +236,10 @@ class MQLLMEngine:
self
.
start_profile
()
self
.
start_profile
()
else
:
else
:
self
.
stop_profile
()
self
.
stop_profile
()
elif
isinstance
(
request
,
RPCLoadAdapterRequest
):
self
.
_handle_load_adapter_request
(
request
)
elif
isinstance
(
request
,
RPCResetPrefixCacheRequest
):
self
.
reset_prefix_cache
()
else
:
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
f
"
{
type
(
request
)
}
"
)
...
@@ -284,6 +290,20 @@ class MQLLMEngine:
...
@@ -284,6 +290,20 @@ class MQLLMEngine:
if
self
.
log_requests
:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_handle_load_adapter_request
(
self
,
request
:
RPCLoadAdapterRequest
):
try
:
self
.
engine
.
add_lora
(
request
.
lora_request
)
except
BaseException
as
e
:
# Send back an error if the adapter fails to load
rpc_err
=
RPCError
(
request_id
=
request
.
request_id
,
is_engine_errored
=
False
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
return
# Otherwise, send back the successful load message
self
.
_send_outputs
(
RPCAdapterLoadedResponse
(
request_id
=
request
.
request_id
))
def
_health_check
(
self
):
def
_health_check
(
self
):
# Send unhealthy if engine has already errored
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
if
self
.
_errored_with
is
not
None
:
...
@@ -296,7 +316,11 @@ class MQLLMEngine:
...
@@ -296,7 +316,11 @@ class MQLLMEngine:
self
.
_send_unhealthy
(
e
)
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if
outputs
:
if
outputs
:
try
:
try
:
from
ray.exceptions
import
RayTaskError
from
ray.exceptions
import
RayTaskError
...
@@ -335,16 +359,13 @@ class MQLLMEngine:
...
@@ -335,16 +359,13 @@ class MQLLMEngine:
self
.
_errored_with
=
e
self
.
_errored_with
=
e
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
start_profile
()
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
stop_profile
()
self
.
engine
.
model_executor
.
stop_profile
()
else
:
def
reset_prefix_cache
(
self
)
->
bool
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop
_pr
o
fi
le"
)
return
self
.
engine
.
reset
_pr
e
fi
x_cache
(
)
def
signal_handler
(
*
_
)
->
None
:
def
signal_handler
(
*
_
)
->
None
:
...
...
vllm/engine/output_processor/multi_step.py
View file @
afd0da21
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@
staticmethod
@
staticmethod
@
functools
.
lru_cache
@
functools
.
lru_cache
def
_log_prompt_logprob_unsupported_warning_once
():
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/
usage
/compatibility_matrix.md
# Reminder: Please update docs/source/
features
/compatibility_matrix.md
# If the feature combo becomes valid
# If the feature combo becomes valid
logger
.
warning
(
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
"Prompt logprob is not supported by multi step workers. "
...
@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
new_char_count
=
0
if
sampling_params
.
detokenize
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
seq
,
sampling_params
)
...
...
vllm/engine/output_processor/single_step.py
View file @
afd0da21
...
@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
Args:
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
output
s
: the :class:`SequenceGroupOutput` for a single scheduler step
"""
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only ha
s
1 output."
)
assert
len
(
outputs
)
==
1
,
"Single step should only ha
ve
1 output."
output
=
outputs
[
0
]
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
...
...
vllm/engine/protocol.py
View file @
afd0da21
...
@@ -270,3 +270,13 @@ class EngineClient(ABC):
...
@@ -270,3 +270,13 @@ class EngineClient(ABC):
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
"""Stop profiling the engine"""
...
...
@
abstractmethod
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
...
@
abstractmethod
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
...
vllm/entrypoints/chat_utils.py
View file @
afd0da21
...
@@ -3,10 +3,10 @@ import codecs
...
@@ -3,10 +3,10 @@ import codecs
import
json
import
json
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
functools
import
lru_cache
,
partial
from
functools
import
cache
,
lru_cache
,
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
import
jinja2.nodes
import
jinja2.nodes
import
transformers.utils.chat_template_utils
as
hf_chat_utils
import
transformers.utils.chat_template_utils
as
hf_chat_utils
...
@@ -23,6 +23,8 @@ from openai.types.chat import (
...
@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
from
openai.types.chat
import
(
ChatCompletionMessageToolCallParam
,
from
openai.types.chat
import
(
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
)
ChatCompletionToolMessageParam
)
from
openai.types.chat.chat_completion_content_part_input_audio_param
import
(
InputAudio
)
# yapf: enable
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
# pydantic needs the TypedDict from typing_extensions
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
...
@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
from
vllm.multimodal.utils
import
MediaConnector
async_get_and_parse_image
,
async_get_and_parse_video
,
get_and_parse_audio
,
get_and_parse_image
,
get_and_parse_video
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_tokenizer
=
tokenizer
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
if
model_config
.
multimodal_config
else
{})
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_items
:
List
[
_T
]
=
[]
self
.
_items
_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
@
property
@
property
def
model_config
(
self
)
->
ModelConfig
:
def
model_config
(
self
)
->
ModelConfig
:
return
self
.
_model_config
return
self
.
_model_config
@
property
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
@
staticmethod
@
staticmethod
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
return
tokenizer
.
decode
(
token_index
)
...
@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"phi3_v"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
if
model_type
in
(
"minicpmo"
,
"minicpmv"
)
:
return
"(<image>./</image>)"
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
"pixtral"
):
"pixtral"
):
...
@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
.
startswith
(
"llava"
):
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
,
"NVLM_D"
,
if
model_type
in
(
"chameleon"
,
"deepseek_vl_v2"
,
"internvl_chat"
,
"h2ovl_chat"
):
"NVLM_D"
,
"h2ovl_chat"
):
return
"<image>"
return
"<image>"
if
model_type
==
"mllama"
:
if
model_type
==
"mllama"
:
return
"<|image|>"
return
"<|image|>"
...
@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"qwen2_audio"
:
if
model_type
==
"qwen2_audio"
:
return
(
f
"Audio
{
current_count
}
: "
return
(
f
"Audio
{
current_count
}
: "
f
"<|audio_bos|><|AUDIO|><|audio_eos|>"
)
f
"<|audio_bos|><|AUDIO|><|audio_eos|>"
)
if
model_type
==
"minicpmo"
:
return
"(<audio>./</audio>)"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|video_pad|><|vision_end|>"
return
"<|vision_start|><|video_pad|><|vision_end|>"
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
return
"(<video>./</video>)"
if
model_type
.
startswith
(
"llava"
):
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
video_token_index
)
hf_config
.
video_token_index
)
...
@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else
:
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
@
staticmethod
def
_combine
(
items
:
List
[
MultiModalDataDict
])
->
MultiModalDataDict
:
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
# Merge all the multi-modal items
for
single_mm_data
in
items
:
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
if
isinstance
(
mm_item
,
list
):
mm_lists
[
mm_key
].
extend
(
mm_item
)
else
:
mm_lists
[
mm_key
].
append
(
mm_item
)
# Unpack any single item lists for models that don't expect multiple.
return
{
mm_key
:
mm_list
[
0
]
if
len
(
mm_list
)
==
1
else
mm_list
for
mm_key
,
mm_list
in
mm_lists
.
items
()
}
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
"""
"""
Add a multi-modal item to the current prompt and returns the
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
placeholder string to use, if any.
"""
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_
consumed_items
.
get
(
modality
,
0
)
+
1
current_count
=
len
(
self
.
_
items_by_modality
[
modality
]
)
+
1
if
current_count
>
allowed_count
:
if
current_count
>
allowed_count
:
raise
ValueError
(
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_items_by_modality
[
modality
].
append
(
item
)
self
.
_items
.
append
(
item
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
...
@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise
NotImplementedError
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
MultiModalDataDi
ct
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
obje
ct
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
return
self
.
_combine
(
self
.
_items
)
if
self
.
_items
else
None
if
self
.
_items_by_modality
:
return
dict
(
self
.
_items_by_modality
)
return
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
return
MultiModalContentParser
(
self
)
class
AsyncMultiModalItemTracker
(
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
BaseMultiModalItemTracker
[
Awaitable
[
MultiModalDataDict
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
self
.
_items
:
if
self
.
_items_by_modality
:
items
=
await
asyncio
.
gather
(
*
self
.
_items
)
return
{
return
self
.
_combine
(
items
)
modality
:
await
asyncio
.
gather
(
*
items
)
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
return
None
return
None
...
@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
...
@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
]
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
...
@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self
.
_tracker
=
tracker
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image
=
get_and_parse_image
(
image_url
,
image
=
self
.
_connector
.
fetch_image
(
image_url
)
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
get_and_parse
_audio
(
audio_url
)
audio
=
self
.
_connector
.
fetch
_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
audio
=
get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
return
self
.
parse_audio
(
audio_url
)
self
.
_add_placeholder
(
placeholder
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
get_and_parse
_video
(
video_url
)
video
=
self
.
_connector
.
fetch
_video
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
...
@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
...
@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super
().
__init__
()
super
().
__init__
()
self
.
_tracker
=
tracker
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image_coro
=
async_get_and_parse_image
(
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
image_url
,
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
return
self
.
parse_audio
(
audio_url
)
self
.
_add_placeholder
(
placeholder
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
async_get_and_parse_video
(
video_url
)
video
=
self
.
_connector
.
fetch_video_async
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
...
@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
...
@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_VideoParser
=
partial
(
cast
,
ChatCompletionContentPartVideoParam
)
_VideoParser
=
partial
(
cast
,
ChatCompletionContentPartVideoParam
)
_ContentPart
:
TypeAlias
=
Union
[
str
,
Dict
[
str
,
str
],
InputAudio
]
# Define a mapping from part types to their corresponding parsing functions.
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP
:
Dict
[
str
,
MM_PARSER_MAP
:
Dict
[
Callable
[[
ChatCompletionContentPartParam
],
str
,
Union
[
str
,
Dict
[
str
,
str
]]]]
=
{
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
"text"
:
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
"image_url"
:
...
@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
...
@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
def
_parse_chat_message_content_mm_part
(
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
Tuple
[
str
,
part
:
ChatCompletionContentPartParam
)
->
tuple
[
str
,
_ContentPart
]:
Union
[
str
,
Dict
[
str
,
str
]]]:
"""
"""
Parses a given multi-modal content part based on its type.
Parses a given multi-modal content part based on its type.
...
@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
...
@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
*
,
*
,
wrap_dicts
:
bool
,
wrap_dicts
:
bool
,
)
->
List
[
ConversationMessage
]:
)
->
List
[
ConversationMessage
]:
content
:
List
[
Union
[
str
,
Dict
[
str
,
str
]]]
=
[]
content
=
list
[
_ContentPart
]()
mm_parser
=
mm_tracker
.
create_parser
()
mm_parser
=
mm_tracker
.
create_parser
()
...
@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
...
@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
mm_parser
:
BaseMultiModalContentParser
,
mm_parser
:
BaseMultiModalContentParser
,
*
,
*
,
wrap_dicts
:
bool
,
wrap_dicts
:
bool
,
)
->
Optional
[
Union
[
str
,
Dict
[
str
,
str
]]
]:
)
->
Optional
[
_ContentPart
]:
"""Parses a single part of a conversation. If wrap_dicts is True,
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
...
@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
...
@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
with multimodal placeholders.
"""
"""
if
isinstance
(
part
,
str
):
# Handle plain text parts
if
isinstance
(
part
,
str
):
# Handle plain text parts
text
=
_TextParser
(
part
)
return
part
return
text
# Handle structured dictionary parts
# Handle structured dictionary parts
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
...
@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
...
@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
if
part_type
==
"input_audio"
:
if
part_type
==
"input_audio"
:
dict_content
=
cast
(
Dict
[
str
,
str
]
,
content
)
dict_content
=
cast
(
InputAudio
,
content
)
mm_parser
.
parse_input_audio
(
dict_content
)
mm_parser
.
parse_input_audio
(
dict_content
)
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
...
@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
...
@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
List
[
int
]:
)
->
List
[
int
]:
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
print_
warning_once
(
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
if
"add_generation_prompt"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
if
"continue_final_message"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
...
...
vllm/entrypoints/llm.py
View file @
afd0da21
import
itertools
import
itertools
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Callable
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
)
Tuple
,
Type
,
Union
,
cast
,
overload
)
import
cloudpickle
import
torch
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
typing_extensions
import
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm
import
envs
from
vllm
import
envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
...
@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages
,
parse_chat_messages
,
resolve_chat_template_content_format
)
resolve_chat_template_content_format
)
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
is_token_prompt
,
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
from
vllm.model_executor.guided_decoding.guided_fields
import
(
...
@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
...
@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLM
:
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
"""An LLM for generating texts from given prompts and sampling parameters.
...
@@ -186,6 +191,13 @@ class LLM:
...
@@ -186,6 +191,13 @@ class LLM:
if
"disable_log_stats"
not
in
kwargs
:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
kwargs
[
"disable_log_stats"
]
=
True
if
"worker_cls"
in
kwargs
:
worker_cls
=
kwargs
[
"worker_cls"
]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if
isinstance
(
worker_cls
,
type
):
kwargs
[
"worker_cls"
]
=
cloudpickle
.
dumps
(
worker_cls
)
if
compilation_config
is
not
None
:
if
compilation_config
is
not
None
:
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
compilation_config_instance
=
CompilationConfig
.
from_cli
(
compilation_config_instance
=
CompilationConfig
.
from_cli
(
...
@@ -225,18 +237,11 @@ class LLM:
...
@@ -225,18 +237,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
# to avoid import order issues
self
.
engine_class
=
self
.
get_engine_class
()
self
.
engine_class
=
self
.
get_engine_class
()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self
.
llm_engine
=
self
.
engine_class
.
from_engine_args
(
self
.
llm_engine
=
self
.
engine_class
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
def
__del__
(
self
):
if
hasattr
(
self
,
'llm_engine'
)
and
self
.
llm_engine
and
hasattr
(
self
.
llm_engine
,
"shutdown"
):
self
.
llm_engine
.
shutdown
()
@
staticmethod
@
staticmethod
def
get_engine_class
()
->
Type
[
LLMEngine
]:
def
get_engine_class
()
->
Type
[
LLMEngine
]:
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
...
@@ -462,9 +467,47 @@ class LLM:
...
@@ -462,9 +467,47 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
List
[
_R
]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
def
beam_search
(
def
beam_search
(
self
,
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]
]],
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]],
params
:
BeamSearchParams
,
params
:
BeamSearchParams
,
)
->
List
[
BeamSearchOutput
]:
)
->
List
[
BeamSearchOutput
]:
"""
"""
...
@@ -500,8 +543,10 @@ class LLM:
...
@@ -500,8 +543,10 @@ class LLM:
instances
:
List
[
BeamSearchInstance
]
=
[]
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
for
prompt
in
prompts
:
prompt_tokens
=
prompt
if
isinstance
(
if
is_token_prompt
(
prompt
):
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
for
_
in
range
(
max_tokens
):
...
@@ -952,6 +997,107 @@ class LLM:
...
@@ -952,6 +997,107 @@ class LLM:
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_embedding_score
(
self
,
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
encoded_output
=
self
.
encode
(
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
encoded_output_1
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
=
encoded_output
[
len
(
text_1
):]
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
output_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
encoded_output_1
,
encoded_output_2
)]
scores
=
[]
scorer
=
torch
.
nn
.
CosineSimilarity
(
0
)
for
embed_1
,
embed_2
in
output_pairs
:
pair_score
=
scorer
(
embed_1
.
outputs
.
data
,
embed_2
.
outputs
.
data
)
if
(
pad_token_id
:
=
getattr
(
tokenizer
,
"pad_token_id"
,
None
))
is
not
None
:
tokens
=
embed_1
.
prompt_token_ids
+
[
pad_token_id
]
+
embed_2
.
prompt_token_ids
else
:
tokens
=
embed_1
.
prompt_token_ids
+
embed_2
.
prompt_token_ids
scores
.
append
(
PoolingRequestOutput
(
request_id
=
f
"
{
embed_1
.
request_id
}
_
{
embed_2
.
request_id
}
"
,
outputs
=
pair_score
,
prompt_token_ids
=
tokens
,
finished
=
True
))
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_cross_encoding_score
(
self
,
tokenizer
:
Union
[
AnyTokenizer
],
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Score API is only enabled for `--task embed or score`"
)
if
len
(
text_1
)
==
1
:
text_1
=
text_1
*
len
(
text_2
)
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
score
(
def
score
(
self
,
self
,
text_1
:
Union
[
SingletonPrompt
,
Sequence
[
SingletonPrompt
]],
text_1
:
Union
[
SingletonPrompt
,
Sequence
[
SingletonPrompt
]],
...
@@ -1003,25 +1149,20 @@ class LLM:
...
@@ -1003,25 +1149,20 @@ class LLM:
raise
ValueError
(
" "
.
join
(
messages
))
raise
ValueError
(
" "
.
join
(
messages
))
if
not
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
if
self
.
llm_engine
.
model_config
.
task
not
in
(
"embed"
,
"score"
):
raise
ValueError
(
"Your model does not support cross encoding"
)
if
self
.
llm_engine
.
model_config
.
task
!=
"score"
:
raise
ValueError
(
"Score API is only enabled for `--task score`"
)
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
raise
ValueError
(
"
MistralTokenizer not support
ed
f
or
cross-encoding
"
)
"
Score API is only enabled for `--task emb
ed or
--task score`
"
)
# the tokenizer for models such as
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
def
ensure_str
(
prompt
:
SingletonPrompt
):
def
ensure_str
(
prompt
:
SingletonPrompt
):
if
isinstance
(
prompt
,
dict
):
if
isinstance
(
prompt
,
dict
):
if
"multi_modal_data"
in
prompt
:
if
"multi_modal_data"
in
prompt
:
raise
ValueError
(
"Multi-modal prompt is not "
raise
ValueError
(
"Multi-modal prompt is not "
"supported for
cross en
co
d
ing"
)
"supported for
s
co
r
ing"
)
elif
"prompt_token_ids"
in
prompt
:
elif
"prompt_token_ids"
in
prompt
:
prompt
=
tokenizer
.
decode
(
prompt
=
tokenizer
.
decode
(
cast
(
TokensPrompt
,
prompt
)[
"prompt_token_ids"
])
cast
(
TokensPrompt
,
prompt
)[
"prompt_token_ids"
])
...
@@ -1047,40 +1188,15 @@ class LLM:
...
@@ -1047,40 +1188,15 @@ class LLM:
if
len
(
text_2
)
==
0
:
if
len
(
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
raise
ValueError
(
"At least one text_pair element must be given"
)
if
len
(
text_1
)
==
1
:
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
text_1
=
text_1
*
len
(
text_2
)
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
lora_request
,
pooling_params
=
PoolingParams
()
prompt_adapter_request
)
else
:
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
if
truncate_prompt_tokens
is
not
None
:
truncate_prompt_tokens
,
use_tqdm
,
tokenization_kwargs
[
"truncation"
]
=
True
lora_request
,
prompt_adapter_request
)
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
self
.
llm_engine
.
start_profile
()
...
@@ -1088,6 +1204,36 @@ class LLM:
...
@@ -1088,6 +1204,36 @@ class LLM:
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
self
.
llm_engine
.
stop_profile
()
self
.
llm_engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
llm_engine
.
reset_prefix_cache
()
def
sleep
(
self
,
level
:
int
=
1
):
"""
Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache is
forgotten. Level 1 sleep is good for sleeping and waking up the
engine to run the same model again. The model weights are backed
up in CPU memory. Please make sure there's enough CPU memory to
store the model weights. Level 2 sleep will discard both the model
weights and the kv cache. The content of both the model weights
and kv cache is forgotten. Level 2 sleep is good for sleeping and
waking up the engine to run a different model or update the model,
where previous model weights are not needed. It reduces CPU memory
pressure.
"""
self
.
reset_prefix_cache
()
self
.
llm_engine
.
sleep
(
level
=
level
)
def
wake_up
(
self
):
"""
Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details."""
self
.
llm_engine
.
wake_up
()
# LEGACY
# LEGACY
def
_convert_v1_inputs
(
def
_convert_v1_inputs
(
self
,
self
,
...
...
Prev
1
…
26
27
28
29
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment