Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4bdf7ac5
Unverified
Commit
4bdf7ac5
authored
Oct 09, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 09, 2025
Browse files
[Bugfix] Fix SHM cache initialization (#26427)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
dc7976dd
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
303 additions
and
351 deletions
+303
-351
tests/entrypoints/openai/test_lora_resolvers.py
tests/entrypoints/openai/test_lora_resolvers.py
+5
-3
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+29
-16
tests/entrypoints/openai/test_serving_engine.py
tests/entrypoints/openai/test_serving_engine.py
+3
-1
tests/entrypoints/openai/test_serving_models.py
tests/entrypoints/openai/test_serving_models.py
+4
-2
tests/entrypoints/openai/test_serving_responses.py
tests/entrypoints/openai/test_serving_responses.py
+9
-5
tests/test_inputs.py
tests/test_inputs.py
+5
-3
tests/v1/engine/test_processor_multi_modal_uuids.py
tests/v1/engine/test_processor_multi_modal_uuids.py
+1
-1
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+1
-1
vllm/benchmarks/throughput.py
vllm/benchmarks/throughput.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+11
-201
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+17
-42
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+5
-18
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+3
-12
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-6
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+0
-3
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-4
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+0
-3
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+199
-17
vllm/entrypoints/openai/serving_models.py
vllm/entrypoints/openai/serving_models.py
+6
-6
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+0
-6
No files found.
tests/entrypoints/openai/test_lora_resolvers.py
View file @
4bdf7ac5
...
@@ -113,15 +113,17 @@ def mock_serving_setup():
...
@@ -113,15 +113,17 @@ def mock_serving_setup():
mock_engine
.
generate
.
reset_mock
()
mock_engine
.
generate
.
reset_mock
()
mock_engine
.
add_lora
.
reset_mock
()
mock_engine
.
add_lora
.
reset_mock
()
mock_model_config
=
MockModelConfig
()
mock_engine
.
model_config
=
MockModelConfig
()
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
models
=
OpenAIServingModels
(
models
=
OpenAIServingModels
(
engine_client
=
mock_engine
,
engine_client
=
mock_engine
,
base_model_paths
=
BASE_MODEL_PATHS
,
base_model_paths
=
BASE_MODEL_PATHS
,
model_config
=
mock_model_config
,
)
)
serving_completion
=
OpenAIServingCompletion
(
serving_completion
=
OpenAIServingCompletion
(
mock_engine
,
mock_model_config
,
models
,
request_logger
=
None
mock_engine
,
models
,
request_logger
=
None
)
)
serving_completion
.
_process_inputs
=
AsyncMock
(
serving_completion
.
_process_inputs
=
AsyncMock
(
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
4bdf7ac5
...
@@ -245,17 +245,13 @@ class MockModelConfig:
...
@@ -245,17 +245,13 @@ class MockModelConfig:
return
self
.
diff_sampling_param
or
{}
return
self
.
diff_sampling_param
or
{}
def
_build_serving_chat
(
def
_build_serving_chat
(
engine
:
AsyncLLM
)
->
OpenAIServingChat
:
engine
:
AsyncLLM
,
model_config
:
MockModelConfig
)
->
OpenAIServingChat
:
models
=
OpenAIServingModels
(
models
=
OpenAIServingModels
(
engine_client
=
engine
,
engine_client
=
engine
,
base_model_paths
=
BASE_MODEL_PATHS
,
base_model_paths
=
BASE_MODEL_PATHS
,
model_config
=
model_config
,
)
)
serving_chat
=
OpenAIServingChat
(
serving_chat
=
OpenAIServingChat
(
engine
,
engine
,
model_config
,
models
,
models
,
response_role
=
"assistant"
,
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
,
chat_template
=
CHAT_TEMPLATE
,
...
@@ -280,18 +276,17 @@ def _build_serving_chat(
...
@@ -280,18 +276,17 @@ def _build_serving_chat(
@
dataclass
@
dataclass
class
MockEngine
:
class
MockEngine
:
async
def
get_model_config
(
self
):
model_config
:
MockModelConfig
=
field
(
default_factory
=
MockModelConfig
)
return
MockModelConfig
()
processor
:
MagicMock
=
field
(
default_factory
=
MagicMock
)
io_processor
:
MagicMock
=
field
(
default_factory
=
MagicMock
)
async
def
_async_serving_chat_init
():
async
def
_async_serving_chat_init
():
engine
=
MockEngine
()
engine
=
MockEngine
()
model_config
=
await
engine
.
get_model_config
()
models
=
OpenAIServingModels
(
engine
,
model_config
,
BASE_MODEL_PATHS
)
models
=
OpenAIServingModels
(
engine
,
BASE_MODEL_PATHS
)
serving_completion
=
OpenAIServingChat
(
serving_completion
=
OpenAIServingChat
(
engine
,
engine
,
model_config
,
models
,
models
,
response_role
=
"assistant"
,
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
,
chat_template
=
CHAT_TEMPLATE
,
...
@@ -311,8 +306,11 @@ async def test_serving_chat_returns_correct_model_name():
...
@@ -311,8 +306,11 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
MockModelConfig
()
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
serving_chat
=
_build_serving_chat
(
mock_engine
,
MockModelConfig
()
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}]
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}]
async
def
return_model_name
(
*
args
):
async
def
return_model_name
(
*
args
):
...
@@ -338,8 +336,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
...
@@ -338,8 +336,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
MockModelConfig
()
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
serving_chat
=
_build_serving_chat
(
mock_engine
,
MockModelConfig
()
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
...
@@ -368,9 +369,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
...
@@ -368,9 +369,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
mock_model_config
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
# Initialize the serving chat
# Initialize the serving chat
serving_chat
=
_build_serving_chat
(
mock_engine
,
mock_model_config
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
# Test Case 1: No max_tokens specified in request
# Test Case 1: No max_tokens specified in request
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
...
@@ -410,9 +414,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
...
@@ -410,9 +414,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
mock_model_config
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
# Initialize the serving chat
# Initialize the serving chat
serving_chat
=
_build_serving_chat
(
mock_engine
,
mock_model_config
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
# Test case 1: No max_tokens specified, defaults to context_window
# Test case 1: No max_tokens specified, defaults to context_window
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
...
@@ -453,9 +460,12 @@ async def test_serving_chat_could_load_correct_generation_config():
...
@@ -453,9 +460,12 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
mock_model_config
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
# Initialize the serving chat
# Initialize the serving chat
serving_chat
=
_build_serving_chat
(
mock_engine
,
mock_model_config
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
...
@@ -496,8 +506,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
...
@@ -496,8 +506,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
mock_model_config
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
serving_chat
=
_build_serving_chat
(
mock_engine
,
mock_model_config
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
# Test cache_salt
# Test cache_salt
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
...
...
tests/entrypoints/openai/test_serving_engine.py
View file @
4bdf7ac5
...
@@ -22,10 +22,12 @@ def serving() -> OpenAIServing:
...
@@ -22,10 +22,12 @@ def serving() -> OpenAIServing:
model_config
=
Mock
(
spec
=
ModelConfig
)
model_config
=
Mock
(
spec
=
ModelConfig
)
model_config
.
max_model_len
=
32768
model_config
.
max_model_len
=
32768
models
=
Mock
(
spec
=
OpenAIServingModels
)
models
=
Mock
(
spec
=
OpenAIServingModels
)
models
.
model_config
=
model_config
models
.
processor
=
Mock
()
models
.
io_processor
=
Mock
()
serving
=
OpenAIServing
(
serving
=
OpenAIServing
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
None
,
request_logger
=
None
,
)
)
...
...
tests/entrypoints/openai/test_serving_models.py
View file @
4bdf7ac5
...
@@ -25,15 +25,17 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
...
@@ -25,15 +25,17 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async
def
_async_serving_models_init
()
->
OpenAIServingModels
:
async
def
_async_serving_models_init
()
->
OpenAIServingModels
:
mock_model_config
=
MagicMock
(
spec
=
ModelConfig
)
mock_engine_client
=
MagicMock
(
spec
=
EngineClient
)
mock_engine_client
=
MagicMock
(
spec
=
EngineClient
)
# Set the max_model_len attribute to avoid missing attribute
# Set the max_model_len attribute to avoid missing attribute
mock_model_config
=
MagicMock
(
spec
=
ModelConfig
)
mock_model_config
.
max_model_len
=
2048
mock_model_config
.
max_model_len
=
2048
mock_engine_client
.
model_config
=
mock_model_config
mock_engine_client
.
processor
=
MagicMock
()
mock_engine_client
.
io_processor
=
MagicMock
()
serving_models
=
OpenAIServingModels
(
serving_models
=
OpenAIServingModels
(
engine_client
=
mock_engine_client
,
engine_client
=
mock_engine_client
,
base_model_paths
=
BASE_MODEL_PATHS
,
base_model_paths
=
BASE_MODEL_PATHS
,
model_config
=
mock_model_config
,
lora_modules
=
None
,
lora_modules
=
None
,
)
)
await
serving_models
.
init_static_loras
()
await
serving_models
.
init_static_loras
()
...
...
tests/entrypoints/openai/test_serving_responses.py
View file @
4bdf7ac5
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
AsyncExitStack
from
contextlib
import
AsyncExitStack
from
unittest.mock
import
AsyncMock
,
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
import
pytest_asyncio
import
pytest_asyncio
...
@@ -70,11 +70,14 @@ class TestInitializeToolSessions:
...
@@ -70,11 +70,14 @@ class TestInitializeToolSessions:
"""Create a real OpenAIServingResponses instance for testing"""
"""Create a real OpenAIServingResponses instance for testing"""
# Create minimal mocks for required dependencies
# Create minimal mocks for required dependencies
engine_client
=
MagicMock
()
engine_client
=
MagicMock
()
engine_client
.
get_model_config
=
AsyncMock
()
model_config
=
MagicMock
()
model_config
=
MagicMock
()
model_config
.
hf_config
.
model_type
=
"test"
model_config
.
hf_config
.
model_type
=
"test"
model_config
.
get_diff_sampling_param
.
return_value
=
{}
model_config
.
get_diff_sampling_param
.
return_value
=
{}
engine_client
.
model_config
=
model_config
engine_client
.
processor
=
MagicMock
()
engine_client
.
io_processor
=
MagicMock
()
models
=
MagicMock
()
models
=
MagicMock
()
...
@@ -83,7 +86,6 @@ class TestInitializeToolSessions:
...
@@ -83,7 +86,6 @@ class TestInitializeToolSessions:
# Create the actual instance
# Create the actual instance
instance
=
OpenAIServingResponses
(
instance
=
OpenAIServingResponses
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
None
,
request_logger
=
None
,
chat_template
=
None
,
chat_template
=
None
,
...
@@ -132,18 +134,20 @@ class TestValidateGeneratorInput:
...
@@ -132,18 +134,20 @@ class TestValidateGeneratorInput:
"""Create a real OpenAIServingResponses instance for testing"""
"""Create a real OpenAIServingResponses instance for testing"""
# Create minimal mocks for required dependencies
# Create minimal mocks for required dependencies
engine_client
=
MagicMock
()
engine_client
=
MagicMock
()
engine_client
.
get_model_config
=
AsyncMock
()
model_config
=
MagicMock
()
model_config
=
MagicMock
()
model_config
.
hf_config
.
model_type
=
"test"
model_config
.
hf_config
.
model_type
=
"test"
model_config
.
get_diff_sampling_param
.
return_value
=
{}
model_config
.
get_diff_sampling_param
.
return_value
=
{}
engine_client
.
model_config
=
model_config
engine_client
.
processor
=
MagicMock
()
engine_client
.
io_processor
=
MagicMock
()
models
=
MagicMock
()
models
=
MagicMock
()
# Create the actual instance
# Create the actual instance
instance
=
OpenAIServingResponses
(
instance
=
OpenAIServingResponses
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
None
,
request_logger
=
None
,
chat_template
=
None
,
chat_template
=
None
,
...
...
tests/test_inputs.py
View file @
4bdf7ac5
...
@@ -7,6 +7,7 @@ from vllm.config import ModelConfig
...
@@ -7,6 +7,7 @@ from vllm.config import ModelConfig
from
vllm.inputs
import
zip_enc_dec_prompts
from
vllm.inputs
import
zip_enc_dec_prompts
from
vllm.inputs.parse
import
parse_raw_prompts
from
vllm.inputs.parse
import
parse_raw_prompts
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.transformers_utils.tokenizer
import
init_tokenizer_from_configs
pytestmark
=
pytest
.
mark
.
cpu_test
pytestmark
=
pytest
.
mark
.
cpu_test
...
@@ -106,7 +107,8 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
...
@@ -106,7 +107,8 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
)
)
def
test_preprocessor_text_no_mm_inputs
(
model_id
,
prompt
):
def
test_preprocessor_text_no_mm_inputs
(
model_id
,
prompt
):
model_config
=
ModelConfig
(
model
=
model_id
)
model_config
=
ModelConfig
(
model
=
model_id
)
input_preprocessor
=
InputPreprocessor
(
model_config
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
)
input_preprocessor
=
InputPreprocessor
(
model_config
,
tokenizer
)
with
pytest
.
raises
(
ValueError
,
match
=
"does not support multimodal inputs"
):
with
pytest
.
raises
(
ValueError
,
match
=
"does not support multimodal inputs"
):
input_preprocessor
.
preprocess
(
prompt
)
input_preprocessor
.
preprocess
(
prompt
)
...
@@ -127,8 +129,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
...
@@ -127,8 +129,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
)
)
def
test_preprocessor_always_mm_code_path
(
model_id
,
prompt
):
def
test_preprocessor_always_mm_code_path
(
model_id
,
prompt
):
model_config
=
ModelConfig
(
model
=
model_id
)
model_config
=
ModelConfig
(
model
=
model_id
)
input_preprocessor
=
InputPreprocessor
(
model_config
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
)
tokenize
r
=
i
nput
_p
reprocessor
.
tokenizer
input_preprocesso
r
=
I
nput
P
reprocessor
(
model_config
,
tokenizer
)
# HF processor adds sep token
# HF processor adds sep token
sep_token_id
=
tokenizer
.
vocab
[
tokenizer
.
sep_token
]
sep_token_id
=
tokenizer
.
vocab
[
tokenizer
.
sep_token
]
...
...
tests/v1/engine/test_processor_multi_modal_uuids.py
View file @
4bdf7ac5
...
@@ -65,7 +65,7 @@ def _mk_processor(
...
@@ -65,7 +65,7 @@ def _mk_processor(
device_config
=
DeviceConfig
(
device
=
"cpu"
),
device_config
=
DeviceConfig
(
device
=
"cpu"
),
)
)
return
Processor
(
vllm_config
)
return
Processor
(
vllm_config
,
tokenizer
=
None
)
def
test_multi_modal_uuids_length_mismatch_raises
(
monkeypatch
):
def
test_multi_modal_uuids_length_mismatch_raises
(
monkeypatch
):
...
...
tests/v1/sample/test_logprobs.py
View file @
4bdf7ac5
...
@@ -459,7 +459,7 @@ def test_all_logprobs(example_prompts):
...
@@ -459,7 +459,7 @@ def test_all_logprobs(example_prompts):
results_logprobs_all
=
runner
.
llm
.
generate
(
results_logprobs_all
=
runner
.
llm
.
generate
(
example_prompts
,
sampling_params
=
sampling_params_logprobs_all
example_prompts
,
sampling_params
=
sampling_params_logprobs_all
)
)
vocab_size
=
runner
.
llm
.
llm_engine
.
get_
model_config
()
.
get_vocab_size
()
vocab_size
=
runner
.
llm
.
llm_engine
.
model_config
.
get_vocab_size
()
for
i
in
range
(
len
(
results_logprobs_all
)):
for
i
in
range
(
len
(
results_logprobs_all
)):
logprobs
=
results_logprobs_all
[
i
].
outputs
[
0
].
logprobs
logprobs
=
results_logprobs_all
[
i
].
outputs
[
0
].
logprobs
...
...
vllm/benchmarks/throughput.py
View file @
4bdf7ac5
...
@@ -186,7 +186,7 @@ async def run_vllm_async(
...
@@ -186,7 +186,7 @@ async def run_vllm_async(
engine_args
,
engine_args
,
disable_frontend_multiprocessing
=
disable_frontend_multiprocessing
,
disable_frontend_multiprocessing
=
disable_frontend_multiprocessing
,
)
as
llm
:
)
as
llm
:
model_config
=
await
llm
.
get_
model_config
()
model_config
=
llm
.
model_config
assert
all
(
assert
all
(
model_config
.
max_model_len
model_config
.
max_model_len
>=
(
request
.
prompt_len
+
request
.
expected_output_len
)
>=
(
request
.
prompt_len
+
request
.
expected_output_len
)
...
...
vllm/engine/protocol.py
View file @
4bdf7ac5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
AsyncGenerator
,
Iterable
,
Mapping
from
collections.abc
import
AsyncGenerator
,
Iterable
,
Mapping
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.plugins.io_processors
.interface
import
IOProcessor
from
vllm.plugins.io_processors
import
IOProcessor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Device
,
collect_from_async_generator
,
random_uuid
from
vllm.utils
import
Device
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.processor
import
Processor
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -28,6 +25,11 @@ logger = init_logger(__name__)
...
@@ -28,6 +25,11 @@ logger = init_logger(__name__)
class
EngineClient
(
ABC
):
class
EngineClient
(
ABC
):
"""Protocol class for Clients to Engine"""
"""Protocol class for Clients to Engine"""
vllm_config
:
VllmConfig
model_config
:
ModelConfig
processor
:
Processor
io_processor
:
Optional
[
IOProcessor
]
@
property
@
property
@
abstractmethod
@
abstractmethod
def
is_running
(
self
)
->
bool
:
...
def
is_running
(
self
)
->
bool
:
...
...
@@ -61,180 +63,6 @@ class EngineClient(ABC):
...
@@ -61,180 +63,6 @@ class EngineClient(ABC):
"""Generate outputs for a request."""
"""Generate outputs for a request."""
...
...
async
def
beam_search
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
params
:
BeamSearchParams
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
include_stop_str_in_output
=
params
.
include_stop_str_in_output
preprocessor
=
await
self
.
get_input_preprocessor
()
tokenizer
=
preprocessor
.
get_tokenizer
()
eos_token_id
=
tokenizer
.
eos_token_id
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
NotImplementedError
else
:
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
if
processed_inputs
[
"type"
]
==
"embeds"
:
raise
NotImplementedError
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
if
isinstance
(
prompt
,
str
):
prompt_text
=
prompt
prompt_token_ids
=
[]
multi_modal_data
=
None
else
:
prompt_text
=
prompt
.
get
(
"prompt"
)
prompt_token_ids
=
prompt
.
get
(
"prompt_token_ids"
,
[])
multi_modal_data
=
prompt
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
processed_inputs
.
get
(
"mm_processor_kwargs"
)
tokenized_length
=
len
(
prompt_token_ids
)
sort_beams_key
=
create_sort_beams_key_function
(
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
,
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
prompt_token_ids
,
cum_logprob
=
0
,
logprobs
=
[],
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
lora_request
=
lora_request
,
)
]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
,
lora_req_batch
=
zip
(
*
[
(
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
,
multi_modal_data
=
beam
.
multi_modal_data
,
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
,
),
beam
.
lora_request
,
)
for
beam
in
all_beams
]
)
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
(
individual_prompt
,
lora_req
)
in
enumerate
(
zip
(
prompts_batch
,
lora_req_batch
)
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
,
lora_request
=
lora_req
,
)
)
)
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
if
token_id
==
eos_token_id
and
not
ignore_eos
:
completed
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
]
if
include_stop_str_in_output
else
current_beam
.
tokens
,
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
finish_reason
=
"stop"
,
stop_reason
=
eos_token_id
,
)
)
else
:
new_beams
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
lora_request
=
current_beam
.
lora_request
,
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
,
)
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
if
beam
.
tokens
[
-
1
]
==
eos_token_id
and
not
ignore_eos
:
# Skip the eos token in the text.
tokens
=
beam
.
tokens
[
tokenized_length
:
-
1
]
else
:
tokens
=
beam
.
tokens
[
tokenized_length
:]
beam
.
text
=
tokenizer
.
decode
(
tokens
)
yield
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt_text
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
[
tokenized_length
:],
index
=
i
,
logprobs
=
beam
.
logprobs
,
finish_reason
=
beam
.
finish_reason
if
beam
.
finish_reason
is
not
None
else
"length"
,
stop_reason
=
beam
.
stop_reason
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
None
,
)
@
abstractmethod
@
abstractmethod
def
encode
(
def
encode
(
self
,
self
,
...
@@ -259,29 +87,11 @@ class EngineClient(ABC):
...
@@ -259,29 +87,11 @@ class EngineClient(ABC):
"""
"""
...
...
@
abstractmethod
async
def
get_vllm_config
(
self
)
->
VllmConfig
:
"""Get the vllm configuration of the vLLM engine."""
...
@
abstractmethod
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
...
@
abstractmethod
async
def
get_input_preprocessor
(
self
)
->
InputPreprocessor
:
"""Get the input processor of the vLLM engine."""
...
@
abstractmethod
@
abstractmethod
async
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
async
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
"""Get the tokenizer"""
"""Get the tokenizer"""
...
...
async
def
get_io_processor
(
self
)
->
IOProcessor
:
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
...
vllm/entrypoints/llm.py
View file @
4bdf7ac5
...
@@ -66,7 +66,6 @@ from vllm.outputs import (
...
@@ -66,7 +66,6 @@ from vllm.outputs import (
RequestOutput
,
RequestOutput
,
ScoringRequestOutput
,
ScoringRequestOutput
,
)
)
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
BeamSearchParams
,
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
RequestOutputKind
,
SamplingParams
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
...
@@ -79,7 +78,6 @@ from vllm.usage.usage_lib import UsageContext
...
@@ -79,7 +78,6 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils
import
Counter
,
Device
,
as_iter
,
is_list_of
from
vllm.utils
import
Counter
,
Device
,
as_iter
,
is_list_of
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -335,21 +333,13 @@ class LLM:
...
@@ -335,21 +333,13 @@ class LLM:
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
self
.
default_sampling_params
:
Union
[
dict
[
str
,
Any
],
None
]
=
None
self
.
default_sampling_params
:
Union
[
dict
[
str
,
Any
],
None
]
=
None
supported_tasks
=
self
.
llm_engine
.
get_supported_tasks
()
# type: ignore
supported_tasks
=
self
.
llm_engine
.
get_supported_tasks
()
logger
.
info
(
"Supported tasks: %s"
,
supported_tasks
)
logger
.
info
(
"Supported_tasks: %s"
,
supported_tasks
)
self
.
supported_tasks
=
supported_tasks
self
.
supported_tasks
=
supported_tasks
# Load the Input/Output processor plugin if any
self
.
model_config
=
self
.
llm_engine
.
model_config
io_processor_plugin
=
self
.
llm_engine
.
model_config
.
io_processor_plugin
self
.
processor
=
self
.
llm_engine
.
processor
self
.
io_processor
=
get_io_processor
(
self
.
io_processor
=
self
.
llm_engine
.
io_processor
self
.
llm_engine
.
vllm_config
,
io_processor_plugin
)
@
property
def
model_config
(
self
):
return
self
.
llm_engine
.
model_config
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer
()
return
self
.
llm_engine
.
get_tokenizer
()
...
@@ -364,18 +354,9 @@ class LLM:
...
@@ -364,18 +354,9 @@ class LLM:
else
:
else
:
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
_get_processor
(
self
)
->
Processor
:
if
not
hasattr
(
self
,
"_processor"
):
vllm_config
=
self
.
llm_engine
.
vllm_config
self
.
_processor
=
Processor
(
vllm_config
)
return
self
.
_processor
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
default_sampling_params
is
None
:
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
(
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
self
.
llm_engine
.
model_config
.
get_diff_sampling_param
()
)
if
self
.
default_sampling_params
:
if
self
.
default_sampling_params
:
return
SamplingParams
.
from_optional
(
**
self
.
default_sampling_params
)
return
SamplingParams
.
from_optional
(
**
self
.
default_sampling_params
)
return
SamplingParams
()
return
SamplingParams
()
...
@@ -423,7 +404,7 @@ class LLM:
...
@@ -423,7 +404,7 @@ class LLM:
considered legacy and may be deprecated in the future. You should
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
instead pass them via the `inputs` parameter.
"""
"""
model_config
=
self
.
llm_engine
.
model_config
model_config
=
self
.
model_config
runner_type
=
model_config
.
runner_type
runner_type
=
model_config
.
runner_type
if
runner_type
!=
"generate"
:
if
runner_type
!=
"generate"
:
raise
ValueError
(
raise
ValueError
(
...
@@ -463,7 +444,7 @@ class LLM:
...
@@ -463,7 +444,7 @@ class LLM:
# isn't multimodal, leave the lora as is.
# isn't multimodal, leave the lora as is.
if
(
if
(
lora_config
is
None
lora_config
is
None
or
not
self
.
llm_engine
.
model_config
.
is_multimodal_model
or
not
self
.
model_config
.
is_multimodal_model
or
(
lora_config
and
lora_config
.
default_mm_loras
is
None
)
or
(
lora_config
and
lora_config
.
default_mm_loras
is
None
)
):
):
return
lora_request
return
lora_request
...
@@ -495,15 +476,13 @@ class LLM:
...
@@ -495,15 +476,13 @@ class LLM:
if
(
if
(
not
default_mm_loras
not
default_mm_loras
or
not
isinstance
(
prompt
,
dict
)
or
not
isinstance
(
prompt
,
dict
)
or
"multi_modal_data"
not
in
prompt
or
not
(
mm_data
:
=
prompt
.
get
(
"multi_modal_data"
)
or
{})
):
):
return
lora_request
return
lora_request
prompt
=
cast
(
Union
[
TextPrompt
,
TokensPrompt
],
prompt
)
intersection
=
set
(
mm_data
.
keys
()
# type: ignore
intersection
=
set
(
prompt
[
"multi_modal_data"
].
keys
()).
intersection
(
).
intersection
(
default_mm_loras
.
keys
())
default_mm_loras
.
keys
()
)
if
not
intersection
:
if
not
intersection
:
return
lora_request
return
lora_request
if
len
(
intersection
)
>
1
:
if
len
(
intersection
)
>
1
:
...
@@ -819,7 +798,7 @@ class LLM:
...
@@ -819,7 +798,7 @@ class LLM:
list_of_messages
=
[
cast
(
list
[
ChatCompletionMessageParam
],
messages
)]
list_of_messages
=
[
cast
(
list
[
ChatCompletionMessageParam
],
messages
)]
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_
model_config
()
model_config
=
self
.
model_config
resolved_content_format
=
resolve_chat_template_content_format
(
resolved_content_format
=
resolve_chat_template_content_format
(
chat_template
,
chat_template
,
tools
,
tools
,
...
@@ -1031,7 +1010,7 @@ class LLM:
...
@@ -1031,7 +1010,7 @@ class LLM:
pooling_task
,
pooling_task
,
)
)
model_config
=
self
.
llm_engine
.
model_config
model_config
=
self
.
model_config
runner_type
=
model_config
.
runner_type
runner_type
=
model_config
.
runner_type
if
runner_type
!=
"pooling"
:
if
runner_type
!=
"pooling"
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1276,7 +1255,7 @@ class LLM:
...
@@ -1276,7 +1255,7 @@ class LLM:
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
list
[
ScoringRequestOutput
]:
)
->
list
[
ScoringRequestOutput
]:
model_config
=
self
.
llm_engine
.
model_config
model_config
=
self
.
model_config
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Score API is not supported for Mistral tokenizer"
)
raise
ValueError
(
"Score API is not supported for Mistral tokenizer"
)
...
@@ -1287,7 +1266,6 @@ class LLM:
...
@@ -1287,7 +1266,6 @@ class LLM:
if
pooling_params
is
None
:
if
pooling_params
is
None
:
pooling_params
=
PoolingParams
(
task
=
"score"
)
pooling_params
=
PoolingParams
(
task
=
"score"
)
model_config
=
self
.
llm_engine
.
model_config
pooling_params
.
verify
(
"score"
,
model_config
)
pooling_params
.
verify
(
"score"
,
model_config
)
pooling_params_list
=
list
[
PoolingParams
]()
pooling_params_list
=
list
[
PoolingParams
]()
...
@@ -1301,8 +1279,6 @@ class LLM:
...
@@ -1301,8 +1279,6 @@ class LLM:
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
model_config
=
self
.
llm_engine
.
model_config
for
q
,
d
in
input_pairs
:
for
q
,
d
in
input_pairs
:
_
,
engine_prompt
=
get_score_prompt
(
_
,
engine_prompt
=
get_score_prompt
(
model_config
=
model_config
,
model_config
=
model_config
,
...
@@ -1380,7 +1356,7 @@ class LLM:
...
@@ -1380,7 +1356,7 @@ class LLM:
A list of `ScoringRequestOutput` objects containing the
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
generated scores in the same order as the input prompts.
"""
"""
model_config
=
self
.
llm_engine
.
model_config
model_config
=
self
.
model_config
runner_type
=
model_config
.
runner_type
runner_type
=
model_config
.
runner_type
if
runner_type
!=
"pooling"
:
if
runner_type
!=
"pooling"
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1658,8 +1634,7 @@ class LLM:
...
@@ -1658,8 +1634,7 @@ class LLM:
tokenization_kwargs
,
tokenization_kwargs
,
)
)
processor
=
self
.
_get_processor
()
engine_request
=
self
.
processor
.
process_inputs
(
engine_request
=
processor
.
process_inputs
(
request_id
,
request_id
,
engine_prompt
,
engine_prompt
,
params
,
params
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
4bdf7ac5
...
@@ -1601,10 +1601,11 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -1601,10 +1601,11 @@ def build_app(args: Namespace) -> FastAPI:
async
def
init_app_state
(
async
def
init_app_state
(
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
vllm_config
:
VllmConfig
,
state
:
State
,
state
:
State
,
args
:
Namespace
,
args
:
Namespace
,
)
->
None
:
)
->
None
:
vllm_config
=
engine_client
.
vllm_config
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
served_model_names
=
args
.
served_model_name
else
:
else
:
...
@@ -1622,11 +1623,9 @@ async def init_app_state(
...
@@ -1622,11 +1623,9 @@ async def init_app_state(
state
.
engine_client
=
engine_client
state
.
engine_client
=
engine_client
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
vllm_config
=
vllm_config
state
.
vllm_config
=
vllm_config
model_config
=
vllm_config
.
model_config
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
logger
.
info
(
"Supported tasks: %s"
,
supported_tasks
)
logger
.
info
(
"Supported_tasks: %s"
,
supported_tasks
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
if
resolved_chat_template
is
not
None
:
if
resolved_chat_template
is
not
None
:
...
@@ -1688,7 +1687,6 @@ async def init_app_state(
...
@@ -1688,7 +1687,6 @@ async def init_app_state(
state
.
openai_serving_models
=
OpenAIServingModels
(
state
.
openai_serving_models
=
OpenAIServingModels
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
base_model_paths
=
base_model_paths
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
)
)
...
@@ -1696,7 +1694,6 @@ async def init_app_state(
...
@@ -1696,7 +1694,6 @@ async def init_app_state(
state
.
openai_serving_responses
=
(
state
.
openai_serving_responses
=
(
OpenAIServingResponses
(
OpenAIServingResponses
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
@@ -1717,7 +1714,6 @@ async def init_app_state(
...
@@ -1717,7 +1714,6 @@ async def init_app_state(
state
.
openai_serving_chat
=
(
state
.
openai_serving_chat
=
(
OpenAIServingChat
(
OpenAIServingChat
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
args
.
response_role
,
args
.
response_role
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
...
@@ -1740,7 +1736,6 @@ async def init_app_state(
...
@@ -1740,7 +1736,6 @@ async def init_app_state(
state
.
openai_serving_completion
=
(
state
.
openai_serving_completion
=
(
OpenAIServingCompletion
(
OpenAIServingCompletion
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
...
@@ -1754,7 +1749,6 @@ async def init_app_state(
...
@@ -1754,7 +1749,6 @@ async def init_app_state(
state
.
openai_serving_pooling
=
(
state
.
openai_serving_pooling
=
(
OpenAIServingPooling
(
OpenAIServingPooling
(
engine_client
,
engine_client
,
vllm_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
@@ -1768,7 +1762,6 @@ async def init_app_state(
...
@@ -1768,7 +1762,6 @@ async def init_app_state(
state
.
openai_serving_embedding
=
(
state
.
openai_serving_embedding
=
(
OpenAIServingEmbedding
(
OpenAIServingEmbedding
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
@@ -1782,7 +1775,6 @@ async def init_app_state(
...
@@ -1782,7 +1775,6 @@ async def init_app_state(
state
.
openai_serving_classification
=
(
state
.
openai_serving_classification
=
(
ServingClassification
(
ServingClassification
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
log_error_stack
=
args
.
log_error_stack
,
...
@@ -1793,7 +1785,6 @@ async def init_app_state(
...
@@ -1793,7 +1785,6 @@ async def init_app_state(
state
.
openai_serving_scores
=
(
state
.
openai_serving_scores
=
(
ServingScores
(
ServingScores
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
log_error_stack
=
args
.
log_error_stack
,
...
@@ -1803,7 +1794,6 @@ async def init_app_state(
...
@@ -1803,7 +1794,6 @@ async def init_app_state(
)
)
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
@@ -1814,7 +1804,6 @@ async def init_app_state(
...
@@ -1814,7 +1804,6 @@ async def init_app_state(
state
.
openai_serving_transcription
=
(
state
.
openai_serving_transcription
=
(
OpenAIServingTranscription
(
OpenAIServingTranscription
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
log_error_stack
=
args
.
log_error_stack
,
...
@@ -1825,7 +1814,6 @@ async def init_app_state(
...
@@ -1825,7 +1814,6 @@ async def init_app_state(
state
.
openai_serving_translation
=
(
state
.
openai_serving_translation
=
(
OpenAIServingTranslation
(
OpenAIServingTranslation
(
engine_client
,
engine_client
,
model_config
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
log_error_stack
=
args
.
log_error_stack
,
...
@@ -1946,12 +1934,11 @@ async def run_server_worker(
...
@@ -1946,12 +1934,11 @@ async def run_server_worker(
maybe_register_tokenizer_info_endpoint
(
args
)
maybe_register_tokenizer_info_endpoint
(
args
)
app
=
build_app
(
args
)
app
=
build_app
(
args
)
vllm_config
=
await
engine_client
.
get_vllm_config
()
await
init_app_state
(
engine_client
,
app
.
state
,
args
)
await
init_app_state
(
engine_client
,
vllm_config
,
app
.
state
,
args
)
logger
.
info
(
logger
.
info
(
"Starting vLLM API server %d on %s"
,
"Starting vLLM API server %d on %s"
,
vllm_config
.
parallel_config
.
_api_process_rank
,
engine_client
.
vllm_config
.
parallel_config
.
_api_process_rank
,
listen_address
,
listen_address
,
)
)
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
...
...
vllm/entrypoints/openai/run_batch.py
View file @
4bdf7ac5
...
@@ -14,7 +14,6 @@ import torch
...
@@ -14,7 +14,6 @@ import torch
from
prometheus_client
import
start_http_server
from
prometheus_client
import
start_http_server
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -328,7 +327,6 @@ async def run_request(
...
@@ -328,7 +327,6 @@ async def run_request(
async
def
run_batch
(
async
def
run_batch
(
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
vllm_config
:
VllmConfig
,
args
:
Namespace
,
args
:
Namespace
,
)
->
None
:
)
->
None
:
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
...
@@ -345,22 +343,19 @@ async def run_batch(
...
@@ -345,22 +343,19 @@ async def run_batch(
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
for
name
in
served_model_names
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
for
name
in
served_model_names
]
]
model_config
=
vllm_config
.
model_config
model_config
=
engine_client
.
model_config
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
logger
.
info
(
"Supported
_
tasks: %s"
,
supported_tasks
)
logger
.
info
(
"Supported
tasks: %s"
,
supported_tasks
)
# Create the openai serving objects.
# Create the openai serving objects.
openai_serving_models
=
OpenAIServingModels
(
openai_serving_models
=
OpenAIServingModels
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
base_model_paths
=
base_model_paths
,
lora_modules
=
None
,
lora_modules
=
None
,
)
)
openai_serving_chat
=
(
openai_serving_chat
=
(
OpenAIServingChat
(
OpenAIServingChat
(
engine_client
,
engine_client
,
model_config
,
openai_serving_models
,
openai_serving_models
,
args
.
response_role
,
args
.
response_role
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
...
@@ -374,7 +369,6 @@ async def run_batch(
...
@@ -374,7 +369,6 @@ async def run_batch(
openai_serving_embedding
=
(
openai_serving_embedding
=
(
OpenAIServingEmbedding
(
OpenAIServingEmbedding
(
engine_client
,
engine_client
,
model_config
,
openai_serving_models
,
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
None
,
chat_template
=
None
,
...
@@ -392,7 +386,6 @@ async def run_batch(
...
@@ -392,7 +386,6 @@ async def run_batch(
openai_serving_scores
=
(
openai_serving_scores
=
(
ServingScores
(
ServingScores
(
engine_client
,
engine_client
,
model_config
,
openai_serving_models
,
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
...
@@ -509,9 +502,7 @@ async def main(args: Namespace):
...
@@ -509,9 +502,7 @@ async def main(args: Namespace):
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
,
disable_frontend_multiprocessing
=
False
,
disable_frontend_multiprocessing
=
False
,
)
as
engine_client
:
)
as
engine_client
:
vllm_config
=
await
engine_client
.
get_vllm_config
()
await
run_batch
(
engine_client
,
args
)
await
run_batch
(
engine_client
,
vllm_config
,
args
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
4bdf7ac5
...
@@ -15,7 +15,6 @@ from fastapi import Request
...
@@ -15,7 +15,6 @@ from fastapi import Request
from
openai_harmony
import
Message
as
OpenAIMessage
from
openai_harmony
import
Message
as
OpenAIMessage
from
pydantic
import
TypeAdapter
from
pydantic
import
TypeAdapter
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
ChatTemplateContentFormatOption
,
...
@@ -81,7 +80,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -81,7 +80,6 @@ class OpenAIServingChat(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
response_role
:
str
,
response_role
:
str
,
*
,
*
,
...
@@ -101,7 +99,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -101,7 +99,6 @@ class OpenAIServingChat(OpenAIServing):
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
...
@@ -138,7 +135,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -138,7 +135,7 @@ class OpenAIServingChat(OpenAIServing):
self
.
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
self
.
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
if
self
.
enable_auto_tools
:
if
self
.
enable_auto_tools
:
try
:
try
:
if
tool_parser
==
"pythonic"
and
model_config
.
model
.
startswith
(
if
tool_parser
==
"pythonic"
and
self
.
model_config
.
model
.
startswith
(
"meta-llama/Llama-3.2"
"meta-llama/Llama-3.2"
):
):
logger
.
warning
(
logger
.
warning
(
...
@@ -169,7 +166,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -169,7 +166,7 @@ class OpenAIServingChat(OpenAIServing):
else
:
else
:
self
.
tool_call_id_type
=
"random"
self
.
tool_call_id_type
=
"random"
self
.
use_harmony
=
model_config
.
hf_config
.
model_type
==
"gpt_oss"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
if
self
.
use_harmony
:
if
self
.
use_harmony
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
self
.
default_sampling_params
[
"stop_token_ids"
]
=
[]
self
.
default_sampling_params
[
"stop_token_ids"
]
=
[]
...
@@ -338,7 +335,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -338,7 +335,7 @@ class OpenAIServingChat(OpenAIServing):
)
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
generator
=
self
.
engine_client
.
beam_search
(
generator
=
self
.
beam_search
(
prompt
=
engine_prompt
,
prompt
=
engine_prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
params
=
sampling_params
,
params
=
sampling_params
,
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
4bdf7ac5
...
@@ -8,7 +8,6 @@ import numpy as np
...
@@ -8,7 +8,6 @@ import numpy as np
from
fastapi
import
Request
from
fastapi
import
Request
from
typing_extensions
import
override
from
typing_extensions
import
override
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
...
@@ -128,7 +127,6 @@ class ServingClassification(ClassificationMixin):
...
@@ -128,7 +127,6 @@ class ServingClassification(ClassificationMixin):
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
...
@@ -136,7 +134,6 @@ class ServingClassification(ClassificationMixin):
...
@@ -136,7 +134,6 @@ class ServingClassification(ClassificationMixin):
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
log_error_stack
=
log_error_stack
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
4bdf7ac5
...
@@ -10,7 +10,6 @@ from typing import Optional, Union, cast
...
@@ -10,7 +10,6 @@ from typing import Optional, Union, cast
import
jinja2
import
jinja2
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
...
@@ -44,7 +43,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -44,7 +43,6 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
...
@@ -55,7 +53,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -55,7 +53,6 @@ class OpenAIServingCompletion(OpenAIServing):
):
):
super
().
__init__
(
super
().
__init__
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
...
@@ -201,7 +198,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -201,7 +198,7 @@ class OpenAIServingCompletion(OpenAIServing):
# but pre-commit in CI fails without it.
# but pre-commit in CI fails without it.
engine_prompt
=
cast
(
Union
[
EmbedsPrompt
,
TokensPrompt
],
engine_prompt
)
engine_prompt
=
cast
(
Union
[
EmbedsPrompt
,
TokensPrompt
],
engine_prompt
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
generator
=
self
.
engine_client
.
beam_search
(
generator
=
self
.
beam_search
(
prompt
=
engine_prompt
,
prompt
=
engine_prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
params
=
sampling_params
,
params
=
sampling_params
,
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
4bdf7ac5
...
@@ -10,7 +10,6 @@ import torch
...
@@ -10,7 +10,6 @@ import torch
from
fastapi
import
Request
from
fastapi
import
Request
from
typing_extensions
import
assert_never
,
override
from
typing_extensions
import
assert_never
,
override
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -597,7 +596,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
...
@@ -597,7 +596,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
...
@@ -608,7 +606,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
...
@@ -608,7 +606,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
log_error_stack
=
log_error_stack
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
4bdf7ac5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
import
json
import
sys
import
sys
import
time
import
time
...
@@ -15,17 +16,13 @@ from pydantic import BaseModel, ConfigDict, Field
...
@@ -15,17 +16,13 @@ from pydantic import BaseModel, ConfigDict, Field
from
starlette.datastructures
import
Headers
from
starlette.datastructures
import
Headers
from
typing_extensions
import
TypeIs
from
typing_extensions
import
TypeIs
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.processor
import
Processor
if
sys
.
version_info
>=
(
3
,
12
):
if
sys
.
version_info
>=
(
3
,
12
):
from
typing
import
TypedDict
from
typing
import
TypedDict
else
:
else
:
from
typing_extensions
import
TypedDict
from
typing_extensions
import
TypedDict
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.
config
import
ModelConfig
from
vllm.
beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
ChatCompletionMessageParam
,
...
@@ -68,9 +65,14 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -68,9 +65,14 @@ from vllm.entrypoints.openai.protocol import (
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
from
vllm.entrypoints.renderer
import
BaseRenderer
,
CompletionRenderer
,
RenderConfig
from
vllm.entrypoints.renderer
import
BaseRenderer
,
CompletionRenderer
,
RenderConfig
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.inputs.parse
import
PromptComponents
,
get_prompt_components
from
vllm.inputs.parse
import
(
PromptComponents
,
get_prompt_components
,
is_explicit_encoder_decoder_prompt
,
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -78,7 +80,7 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
...
@@ -78,7 +80,7 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict
,
MultiModalDataDict
,
MultiModalUUIDDict
,
MultiModalUUIDDict
,
)
)
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.tracing
import
(
from
vllm.tracing
import
(
...
@@ -89,11 +91,13 @@ from vllm.tracing import (
...
@@ -89,11 +91,13 @@ from vllm.tracing import (
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
(
from
vllm.utils
import
(
AsyncMicrobatchTokenizer
,
AsyncMicrobatchTokenizer
,
collect_from_async_generator
,
is_list_of
,
is_list_of
,
make_async
,
make_async
,
merge_async_iterators
,
merge_async_iterators
,
random_uuid
,
random_uuid
,
)
)
from
vllm.v1.engine
import
EngineCoreRequest
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -240,7 +244,6 @@ class OpenAIServing:
...
@@ -240,7 +244,6 @@ class OpenAIServing:
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
...
@@ -251,8 +254,6 @@ class OpenAIServing:
...
@@ -251,8 +254,6 @@ class OpenAIServing:
super
().
__init__
()
super
().
__init__
()
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
models
=
models
self
.
models
=
models
...
@@ -268,12 +269,194 @@ class OpenAIServing:
...
@@ -268,12 +269,194 @@ class OpenAIServing:
self
.
_async_tokenizer_pool
:
dict
[
AnyTokenizer
,
AsyncMicrobatchTokenizer
]
=
{}
self
.
_async_tokenizer_pool
:
dict
[
AnyTokenizer
,
AsyncMicrobatchTokenizer
]
=
{}
self
.
log_error_stack
=
log_error_stack
self
.
log_error_stack
=
log_error_stack
async
def
_get_processor
(
self
)
->
Processor
:
self
.
processor
=
self
.
models
.
processor
if
not
hasattr
(
self
,
"_processor"
):
self
.
io_processor
=
self
.
models
.
io_processor
vllm_config
=
await
self
.
engine_client
.
get_vllm_config
()
self
.
model_config
=
self
.
models
.
model_config
self
.
_processor
=
Processor
(
vllm_config
)
self
.
max_model_len
=
self
.
model_config
.
max_model_len
async
def
beam_search
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
params
:
BeamSearchParams
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
include_stop_str_in_output
=
params
.
include_stop_str_in_output
processor
=
self
.
processor
tokenizer
=
processor
.
tokenizer
if
tokenizer
is
None
:
raise
ValueError
(
"You cannot use beam search when `skip_tokenizer_init` is True"
)
eos_token_id
:
int
=
tokenizer
.
eos_token_id
# type: ignore
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
NotImplementedError
else
:
processed_inputs
=
processor
.
input_preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
if
processed_inputs
[
"type"
]
==
"embeds"
:
raise
NotImplementedError
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text
:
Optional
[
str
]
prompt_token_ids
:
list
[
int
]
multi_modal_data
:
Optional
[
MultiModalDataDict
]
if
isinstance
(
prompt
,
str
):
prompt_text
=
prompt
prompt_token_ids
=
[]
multi_modal_data
=
None
else
:
prompt_text
=
prompt
.
get
(
"prompt"
)
# type: ignore
prompt_token_ids
=
prompt
.
get
(
"prompt_token_ids"
,
[])
# type: ignore
multi_modal_data
=
prompt
.
get
(
"multi_modal_data"
)
# type: ignore
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
processed_inputs
.
get
(
"mm_processor_kwargs"
)
# type: ignore
tokenized_length
=
len
(
prompt_token_ids
)
sort_beams_key
=
create_sort_beams_key_function
(
eos_token_id
,
length_penalty
)
return
self
.
_processor
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
,
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
prompt_token_ids
,
cum_logprob
=
0
,
logprobs
=
[],
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
lora_request
=
lora_request
,
)
]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
,
lora_req_batch
=
zip
(
*
[
(
EngineTokensPrompt
(
prompt_token_ids
=
beam
.
tokens
,
multi_modal_data
=
beam
.
multi_modal_data
,
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
,
),
beam
.
lora_request
,
)
for
beam
in
all_beams
]
)
tasks
=
[]
request_id_batch
=
f
"
{
request_id
}
-
{
random_uuid
()
}
"
for
i
,
(
individual_prompt
,
lora_req
)
in
enumerate
(
zip
(
prompts_batch
,
lora_req_batch
)
):
request_id_item
=
f
"
{
request_id_batch
}
-beam-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
engine_client
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
,
lora_request
=
lora_req
,
)
)
)
tasks
.
append
(
task
)
output
=
[
x
[
0
]
for
x
in
await
asyncio
.
gather
(
*
tasks
)]
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
if
token_id
==
eos_token_id
and
not
ignore_eos
:
completed
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
]
if
include_stop_str_in_output
else
current_beam
.
tokens
,
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
finish_reason
=
"stop"
,
stop_reason
=
eos_token_id
,
)
)
else
:
new_beams
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
lora_request
=
current_beam
.
lora_request
,
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
,
)
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
if
beam
.
tokens
[
-
1
]
==
eos_token_id
and
not
ignore_eos
:
# Skip the eos token in the text.
tokens
=
beam
.
tokens
[
tokenized_length
:
-
1
]
else
:
tokens
=
beam
.
tokens
[
tokenized_length
:]
beam
.
text
=
tokenizer
.
decode
(
tokens
)
yield
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt_text
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
# type: ignore
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
[
tokenized_length
:],
index
=
i
,
logprobs
=
beam
.
logprobs
,
finish_reason
=
beam
.
finish_reason
if
beam
.
finish_reason
is
not
None
else
"length"
,
stop_reason
=
beam
.
stop_reason
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
None
,
)
def
_get_renderer
(
self
,
tokenizer
:
Optional
[
AnyTokenizer
])
->
BaseRenderer
:
def
_get_renderer
(
self
,
tokenizer
:
Optional
[
AnyTokenizer
])
->
BaseRenderer
:
"""
"""
...
@@ -938,8 +1121,7 @@ class OpenAIServing:
...
@@ -938,8 +1121,7 @@ class OpenAIServing:
self
.
max_model_len
,
params
.
truncate_prompt_tokens
,
tokenization_kwargs
self
.
max_model_len
,
params
.
truncate_prompt_tokens
,
tokenization_kwargs
)
)
processor
=
await
self
.
_get_processor
()
engine_request
=
self
.
processor
.
process_inputs
(
engine_request
=
processor
.
process_inputs
(
request_id
,
request_id
,
engine_prompt
,
engine_prompt
,
params
,
params
,
...
...
vllm/entrypoints/openai/serving_models.py
View file @
4bdf7ac5
...
@@ -7,7 +7,6 @@ from dataclasses import dataclass
...
@@ -7,7 +7,6 @@ from dataclasses import dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ErrorInfo
,
ErrorInfo
,
...
@@ -51,18 +50,14 @@ class OpenAIServingModels:
...
@@ -51,18 +50,14 @@ class OpenAIServingModels:
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
base_model_paths
:
list
[
BaseModelPath
],
base_model_paths
:
list
[
BaseModelPath
],
*
,
*
,
lora_modules
:
Optional
[
list
[
LoRAModulePath
]]
=
None
,
lora_modules
:
Optional
[
list
[
LoRAModulePath
]]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
base_model_paths
=
base_model_paths
self
.
max_model_len
=
model_config
.
max_model_len
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
model_
config
=
model_config
self
.
base_
model_
paths
=
base_model_paths
self
.
static_lora_modules
=
lora_modules
self
.
static_lora_modules
=
lora_modules
self
.
lora_requests
:
dict
[
str
,
LoRARequest
]
=
{}
self
.
lora_requests
:
dict
[
str
,
LoRARequest
]
=
{}
...
@@ -75,6 +70,11 @@ class OpenAIServingModels:
...
@@ -75,6 +70,11 @@ class OpenAIServingModels:
)
)
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
processor
=
self
.
engine_client
.
processor
self
.
io_processor
=
self
.
engine_client
.
io_processor
self
.
model_config
=
self
.
engine_client
.
model_config
self
.
max_model_len
=
self
.
model_config
.
max_model_len
async
def
init_static_loras
(
self
):
async
def
init_static_loras
(
self
):
"""Loads all static LoRA modules.
"""Loads all static LoRA modules.
Raises if any fail to load"""
Raises if any fail to load"""
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
4bdf7ac5
...
@@ -13,7 +13,6 @@ import torch
...
@@ -13,7 +13,6 @@ import torch
from
fastapi
import
Request
from
fastapi
import
Request
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
from
vllm.config
import
VllmConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -34,7 +33,6 @@ from vllm.entrypoints.renderer import RenderConfig
...
@@ -34,7 +33,6 @@ from vllm.entrypoints.renderer import RenderConfig
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.utils
import
merge_async_iterators
from
vllm.utils
import
merge_async_iterators
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -60,7 +58,6 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -60,7 +58,6 @@ class OpenAIServingPooling(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
vllm_config
:
VllmConfig
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
...
@@ -71,7 +68,6 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -71,7 +68,6 @@ class OpenAIServingPooling(OpenAIServing):
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
model_config
=
vllm_config
.
model_config
,
models
=
models
,
models
=
models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
log_error_stack
=
log_error_stack
,
...
@@ -80,8 +76,6 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -80,8 +76,6 @@ class OpenAIServingPooling(OpenAIServing):
self
.
chat_template
=
chat_template
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
self
.
trust_request_chat_template
=
trust_request_chat_template
io_processor_plugin
=
self
.
model_config
.
io_processor_plugin
self
.
io_processor
=
get_io_processor
(
vllm_config
,
io_processor_plugin
)
async
def
create_pooling
(
async
def
create_pooling
(
self
,
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment