Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
716 additions
and
177 deletions
+716
-177
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+1
-1
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+3
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+11
-3
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+84
-19
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+5
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+137
-31
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+1
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+10
-2
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+11
-2
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+45
-34
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+234
-0
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+42
-29
vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
+10
-2
vllm/envs.py
vllm/envs.py
+3
-2
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+22
-13
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+1
-1
vllm/inputs/data.py
vllm/inputs/data.py
+20
-0
vllm/inputs/registry.py
vllm/inputs/registry.py
+73
-33
vllm/lora/layers.py
vllm/lora/layers.py
+2
-1
No files found.
vllm/engine/output_processor/multi_step.py
View file @
96ae75ad
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@
staticmethod
@
staticmethod
@
functools
.
lru_cache
@
functools
.
lru_cache
def
_log_prompt_logprob_unsupported_warning_once
():
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/usage/compatibility_matrix.
rst
# Reminder: Please update docs/source/usage/compatibility_matrix.
md
# If the feature combo become valid
# If the feature combo become valid
logger
.
warning
(
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
"Prompt logprob is not supported by multi step workers. "
...
...
vllm/entrypoints/api_server.py
View file @
96ae75ad
...
@@ -21,7 +21,7 @@ from vllm.entrypoints.utils import with_cancellation
...
@@ -21,7 +21,7 @@ from vllm.entrypoints.utils import with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
,
set_ulimit
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
@@ -119,6 +119,8 @@ async def run_server(args: Namespace,
...
@@ -119,6 +119,8 @@ async def run_server(args: Namespace,
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
logger
.
info
(
"args: %s"
,
args
)
set_ulimit
()
app
=
await
init_app
(
args
,
llm_engine
)
app
=
await
init_app
(
args
,
llm_engine
)
assert
engine
is
not
None
assert
engine
is
not
None
...
...
vllm/entrypoints/llm.py
View file @
96ae75ad
...
@@ -115,7 +115,7 @@ class LLM:
...
@@ -115,7 +115,7 @@ class LLM:
integer, it is used as the level of compilation optimization. If it
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine
_
args`)
:ref:`engine
-
args`)
Note:
Note:
This class is intended to be used for offline inference. For online
This class is intended to be used for offline inference. For online
...
@@ -233,7 +233,8 @@ class LLM:
...
@@ -233,7 +233,8 @@ class LLM:
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
def
__del__
(
self
):
def
__del__
(
self
):
if
self
.
llm_engine
and
hasattr
(
self
.
llm_engine
,
"shutdown"
):
if
hasattr
(
self
,
'llm_engine'
)
and
self
.
llm_engine
and
hasattr
(
self
.
llm_engine
,
"shutdown"
):
self
.
llm_engine
.
shutdown
()
self
.
llm_engine
.
shutdown
()
@
staticmethod
@
staticmethod
...
@@ -258,6 +259,13 @@ class LLM:
...
@@ -258,6 +259,13 @@ class LLM:
else
:
else
:
tokenizer_group
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
tokenizer_group
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
diff_sampling_param
=
(
self
.
llm_engine
.
model_config
.
get_diff_sampling_param
())
if
diff_sampling_param
:
return
SamplingParams
.
from_optional
(
**
diff_sampling_param
)
return
SamplingParams
()
@
overload
@
overload
def
generate
(
def
generate
(
self
,
self
,
...
@@ -441,7 +449,7 @@ class LLM:
...
@@ -441,7 +449,7 @@ class LLM:
if
sampling_params
is
None
:
if
sampling_params
is
None
:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
S
ampling
P
arams
()
sampling_params
=
self
.
get_default_s
ampling
_p
arams
()
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
prompts
=
parsed_prompts
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
96ae75ad
...
@@ -27,6 +27,7 @@ from typing_extensions import assert_never
...
@@ -27,6 +27,7 @@ from typing_extensions import assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# type: ignore
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.multiprocessing.engine
import
run_mp_engine
from
vllm.engine.multiprocessing.engine
import
run_mp_engine
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
...
@@ -44,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -44,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest
,
DetokenizeRequest
,
DetokenizeResponse
,
DetokenizeResponse
,
EmbeddingRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
EmbeddingResponse
,
EmbeddingResponseData
,
ErrorResponse
,
LoadLoraAdapterRequest
,
LoadLoraAdapterRequest
,
PoolingRequest
,
PoolingResponse
,
ScoreRequest
,
ScoreResponse
,
ScoreRequest
,
ScoreResponse
,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
,
TokenizeResponse
,
...
@@ -55,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
...
@@ -55,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.entrypoints.openai.serving_pooling
import
OpenAIServingPooling
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_tokenization
import
(
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
OpenAIServingTokenization
)
...
@@ -63,14 +68,9 @@ from vllm.entrypoints.utils import with_cancellation
...
@@ -63,14 +68,9 @@ from vllm.entrypoints.utils import with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
)
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
if
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.async_llm
import
AsyncLLMEngine
# type: ignore
else
:
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# type: ignore
TIMEOUT_KEEP_ALIVE
=
5
# seconds
TIMEOUT_KEEP_ALIVE
=
5
# seconds
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
...
@@ -288,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
...
@@ -288,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return
request
.
app
.
state
.
openai_serving_completion
return
request
.
app
.
state
.
openai_serving_completion
def
pooling
(
request
:
Request
)
->
Optional
[
OpenAIServingPooling
]:
return
request
.
app
.
state
.
openai_serving_pooling
def
embedding
(
request
:
Request
)
->
Optional
[
OpenAIServingEmbedding
]:
def
embedding
(
request
:
Request
)
->
Optional
[
OpenAIServingEmbedding
]:
return
request
.
app
.
state
.
openai_serving_embedding
return
request
.
app
.
state
.
openai_serving_embedding
...
@@ -399,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -399,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
handler
=
embedding
(
raw_request
)
handler
=
embedding
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
fallback_handler
=
pooling
(
raw_request
)
message
=
"The model does not support Embeddings API"
)
if
fallback_handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support Embeddings API"
)
logger
.
warning
(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead."
)
res
=
await
fallback_handler
.
create_pooling
(
request
,
raw_request
)
if
isinstance
(
res
,
PoolingResponse
):
generator
=
EmbeddingResponse
(
id
=
res
.
id
,
object
=
res
.
object
,
created
=
res
.
created
,
model
=
res
.
model
,
data
=
[
EmbeddingResponseData
(
index
=
d
.
index
,
embedding
=
d
.
data
,
# type: ignore
)
for
d
in
res
.
data
],
usage
=
res
.
usage
,
)
else
:
generator
=
res
else
:
generator
=
await
handler
.
create_embedding
(
request
,
raw_request
)
generator
=
await
handler
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -412,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -412,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/pooling"
)
@
with_cancellation
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
handler
=
pooling
(
raw_request
)
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support Pooling API"
)
generator
=
await
handler
.
create_pooling
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
elif
isinstance
(
generator
,
PoolingResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
"/score"
)
@
router
.
post
(
"/score"
)
@
with_cancellation
@
with_cancellation
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
...
@@ -537,12 +585,18 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -537,12 +585,18 @@ def build_app(args: Namespace) -> FastAPI:
status_code
=
401
)
status_code
=
401
)
return
await
call_next
(
request
)
return
await
call_next
(
request
)
@
app
.
middleware
(
"http"
)
if
args
.
enable_request_id_headers
:
async
def
add_request_id
(
request
:
Request
,
call_next
):
logger
.
warning
(
request_id
=
request
.
headers
.
get
(
"X-Request-Id"
)
or
uuid
.
uuid4
().
hex
"CAUTION: Enabling X-Request-Id headers in the API Server. "
response
=
await
call_next
(
request
)
"This can harm performance at high QPS."
)
response
.
headers
[
"X-Request-Id"
]
=
request_id
return
response
@
app
.
middleware
(
"http"
)
async
def
add_request_id
(
request
:
Request
,
call_next
):
request_id
=
request
.
headers
.
get
(
"X-Request-Id"
)
or
uuid
.
uuid4
().
hex
response
=
await
call_next
(
request
)
response
.
headers
[
"X-Request-Id"
]
=
request_id
return
response
for
middleware
in
args
.
middleware
:
for
middleware
in
args
.
middleware
:
module_path
,
object_name
=
middleware
.
rsplit
(
"."
,
1
)
module_path
,
object_name
=
middleware
.
rsplit
(
"."
,
1
)
...
@@ -609,7 +663,7 @@ def init_app_state(
...
@@ -609,7 +663,7 @@ def init_app_state(
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
if
model_config
.
runner_type
==
"generate"
else
None
)
if
model_config
.
runner_type
==
"generate"
else
None
state
.
openai_serving_
embedd
ing
=
OpenAIServing
Embedd
ing
(
state
.
openai_serving_
pool
ing
=
OpenAIServing
Pool
ing
(
engine_client
,
engine_client
,
model_config
,
model_config
,
base_model_paths
,
base_model_paths
,
...
@@ -617,13 +671,20 @@ def init_app_state(
...
@@ -617,13 +671,20 @@ def init_app_state(
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
chat_template_content_format
=
args
.
chat_template_content_format
,
)
if
model_config
.
runner_type
==
"pooling"
else
None
)
if
model_config
.
runner_type
==
"pooling"
else
None
state
.
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine_client
,
model_config
,
base_model_paths
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
)
if
model_config
.
task
==
"embed"
else
None
state
.
openai_serving_scores
=
OpenAIServingScores
(
state
.
openai_serving_scores
=
OpenAIServingScores
(
engine_client
,
engine_client
,
model_config
,
model_config
,
base_model_paths
,
base_model_paths
,
request_logger
=
request_logger
request_logger
=
request_logger
)
if
(
model_config
.
runner_type
==
"pooling"
\
)
if
model_config
.
task
==
"score"
else
None
and
model_config
.
is_cross_encoder
)
else
None
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine_client
,
engine_client
,
model_config
,
model_config
,
...
@@ -666,6 +727,10 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -666,6 +727,10 @@ async def run_server(args, **uvicorn_kwargs) -> None:
sock_addr
=
(
args
.
host
or
""
,
args
.
port
)
sock_addr
=
(
args
.
host
or
""
,
args
.
port
)
sock
=
create_server_socket
(
sock_addr
)
sock
=
create_server_socket
(
sock_addr
)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit
()
def
signal_handler
(
*
_
)
->
None
:
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
# Interrupt server on sigterm while initializing
raise
KeyboardInterrupt
(
"terminated"
)
raise
KeyboardInterrupt
(
"terminated"
)
...
...
vllm/entrypoints/openai/cli_args.py
View file @
96ae75ad
...
@@ -196,7 +196,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -196,7 +196,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"If specified, will run the OpenAI frontend server in the same "
help
=
"If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine."
)
"process as the model serving engine."
)
parser
.
add_argument
(
"--enable-request-id-headers"
,
action
=
"store_true"
,
help
=
"If specified, API server will add X-Request-Id header to "
"responses. Caution: this hurts performance at high QPS."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-auto-tool-choice"
,
"--enable-auto-tool-choice"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
96ae75ad
...
@@ -46,7 +46,15 @@ class OpenAIBaseModel(BaseModel):
...
@@ -46,7 +46,15 @@ class OpenAIBaseModel(BaseModel):
@
classmethod
@
classmethod
def
__log_extra_fields__
(
cls
,
data
):
def
__log_extra_fields__
(
cls
,
data
):
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
extra_fields
=
data
.
keys
()
-
cls
.
model_fields
.
keys
()
# Get all class field names and their potential aliases
field_names
=
set
()
for
field_name
,
field
in
cls
.
model_fields
.
items
():
field_names
.
add
(
field_name
)
if
hasattr
(
field
,
'alias'
)
and
field
.
alias
:
field_names
.
add
(
field
.
alias
)
# Compare against both field names and aliases
extra_fields
=
data
.
keys
()
-
field_names
if
extra_fields
:
if
extra_fields
:
logger
.
warning
(
logger
.
warning
(
"The following fields were present in the request "
"The following fields were present in the request "
...
@@ -211,8 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -211,8 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
temperature
:
Optional
[
float
]
=
1.0
temperature
:
Optional
[
float
]
=
None
top_p
:
Optional
[
float
]
=
1.0
top_p
:
Optional
[
float
]
=
None
tools
:
Optional
[
List
[
ChatCompletionToolsParam
]]
=
None
tools
:
Optional
[
List
[
ChatCompletionToolsParam
]]
=
None
tool_choice
:
Optional
[
Union
[
Literal
[
"none"
],
Literal
[
"auto"
],
tool_choice
:
Optional
[
Union
[
Literal
[
"none"
],
Literal
[
"auto"
],
ChatCompletionNamedToolChoiceParam
]]
=
"none"
ChatCompletionNamedToolChoiceParam
]]
=
"none"
...
@@ -224,9 +232,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -224,9 +232,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
# doc: begin-chat-completion-sampling-params
best_of
:
Optional
[
int
]
=
None
best_of
:
Optional
[
int
]
=
None
use_beam_search
:
bool
=
False
use_beam_search
:
bool
=
False
top_k
:
int
=
-
1
top_k
:
Optional
[
int
]
=
None
min_p
:
float
=
0.0
min_p
:
Optional
[
float
]
=
None
repetition_penalty
:
float
=
1.0
repetition_penalty
:
Optional
[
float
]
=
None
length_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
include_stop_str_in_output
:
bool
=
False
...
@@ -348,15 +356,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -348,15 +356,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
def
to_beam_search_params
(
self
,
# Default sampling parameters for chat completion requests
default_max_tokens
:
int
)
->
BeamSearchParams
:
_DEFAULT_SAMPLING_PARAMS
:
dict
=
{
"repetition_penalty"
:
1.0
,
"temperature"
:
1.0
,
"top_p"
:
1.0
,
"top_k"
:
-
1
,
"min_p"
:
0.0
,
}
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
BeamSearchParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
max_tokens
is
None
:
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
return
BeamSearchParams
(
return
BeamSearchParams
(
beam_width
=
n
,
beam_width
=
n
,
...
@@ -367,13 +392,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -367,13 +392,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
def
to_sampling_params
(
def
to_sampling_params
(
self
,
default_max_tokens
:
int
,
self
,
logits_processor_pattern
:
Optional
[
str
])
->
SamplingParams
:
default_max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
max_tokens
is
None
:
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
"repetition_penalty"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"repetition_penalty"
],
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
if
(
top_p
:
=
self
.
top_p
)
is
None
:
top_p
=
default_sampling_params
.
get
(
"top_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_p"
])
if
(
top_k
:
=
self
.
top_k
)
is
None
:
top_k
=
default_sampling_params
.
get
(
"top_k"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_k"
])
if
(
min_p
:
=
self
.
min_p
)
is
None
:
min_p
=
default_sampling_params
.
get
(
"min_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
prompt_logprobs
=
self
.
prompt_logprobs
prompt_logprobs
=
self
.
prompt_logprobs
if
prompt_logprobs
is
None
and
self
.
echo
:
if
prompt_logprobs
is
None
and
self
.
echo
:
prompt_logprobs
=
self
.
top_logprobs
prompt_logprobs
=
self
.
top_logprobs
...
@@ -403,11 +451,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -403,11 +451,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
best_of
=
self
.
best_of
,
best_of
=
self
.
best_of
,
presence_penalty
=
self
.
presence_penalty
,
presence_penalty
=
self
.
presence_penalty
,
frequency_penalty
=
self
.
frequency_penalty
,
frequency_penalty
=
self
.
frequency_penalty
,
repetition_penalty
=
self
.
repetition_penalty
,
repetition_penalty
=
repetition_penalty
,
temperature
=
self
.
temperature
,
temperature
=
temperature
,
top_p
=
self
.
top_p
,
top_p
=
top_p
,
top_k
=
self
.
top_k
,
top_k
=
top_k
,
min_p
=
self
.
min_p
,
min_p
=
min_p
,
seed
=
self
.
seed
,
seed
=
self
.
seed
,
stop
=
self
.
stop
,
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
stop_token_ids
=
self
.
stop_token_ids
,
...
@@ -584,15 +632,15 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -584,15 +632,15 @@ class CompletionRequest(OpenAIBaseModel):
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
suffix
:
Optional
[
str
]
=
None
suffix
:
Optional
[
str
]
=
None
temperature
:
Optional
[
float
]
=
1.0
temperature
:
Optional
[
float
]
=
None
top_p
:
Optional
[
float
]
=
1.0
top_p
:
Optional
[
float
]
=
None
user
:
Optional
[
str
]
=
None
user
:
Optional
[
str
]
=
None
# doc: begin-completion-sampling-params
# doc: begin-completion-sampling-params
use_beam_search
:
bool
=
False
use_beam_search
:
bool
=
False
top_k
:
int
=
-
1
top_k
:
Optional
[
int
]
=
None
min_p
:
float
=
0.0
min_p
:
Optional
[
float
]
=
None
repetition_penalty
:
float
=
1.0
repetition_penalty
:
Optional
[
float
]
=
None
length_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
include_stop_str_in_output
:
bool
=
False
...
@@ -669,14 +717,30 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -669,14 +717,30 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
# doc: end-completion-extra-params
def
to_beam_search_params
(
self
,
# Default sampling parameters for completion requests
default_max_tokens
:
int
)
->
BeamSearchParams
:
_DEFAULT_SAMPLING_PARAMS
:
dict
=
{
"repetition_penalty"
:
1.0
,
"temperature"
:
1.0
,
"top_p"
:
1.0
,
"top_k"
:
-
1
,
"min_p"
:
0.0
,
}
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
1.0
)
return
BeamSearchParams
(
return
BeamSearchParams
(
beam_width
=
n
,
beam_width
=
n
,
...
@@ -687,12 +751,35 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -687,12 +751,35 @@ class CompletionRequest(OpenAIBaseModel):
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
def
to_sampling_params
(
def
to_sampling_params
(
self
,
default_max_tokens
:
int
,
self
,
logits_processor_pattern
:
Optional
[
str
])
->
SamplingParams
:
default_max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
"repetition_penalty"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"repetition_penalty"
],
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
if
(
top_p
:
=
self
.
top_p
)
is
None
:
top_p
=
default_sampling_params
.
get
(
"top_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_p"
])
if
(
top_k
:
=
self
.
top_k
)
is
None
:
top_k
=
default_sampling_params
.
get
(
"top_k"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_k"
])
if
(
min_p
:
=
self
.
min_p
)
is
None
:
min_p
=
default_sampling_params
.
get
(
"min_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
prompt_logprobs
=
self
.
prompt_logprobs
prompt_logprobs
=
self
.
prompt_logprobs
if
prompt_logprobs
is
None
and
self
.
echo
:
if
prompt_logprobs
is
None
and
self
.
echo
:
prompt_logprobs
=
self
.
logprobs
prompt_logprobs
=
self
.
logprobs
...
@@ -718,11 +805,11 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -718,11 +805,11 @@ class CompletionRequest(OpenAIBaseModel):
best_of
=
self
.
best_of
,
best_of
=
self
.
best_of
,
presence_penalty
=
self
.
presence_penalty
,
presence_penalty
=
self
.
presence_penalty
,
frequency_penalty
=
self
.
frequency_penalty
,
frequency_penalty
=
self
.
frequency_penalty
,
repetition_penalty
=
self
.
repetition_penalty
,
repetition_penalty
=
repetition_penalty
,
temperature
=
self
.
temperature
,
temperature
=
temperature
,
top_p
=
self
.
top_p
,
top_p
=
top_p
,
top_k
=
self
.
top_k
,
top_k
=
top_k
,
min_p
=
self
.
min_p
,
min_p
=
min_p
,
seed
=
self
.
seed
,
seed
=
self
.
seed
,
stop
=
self
.
stop
,
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
stop_token_ids
=
self
.
stop_token_ids
,
...
@@ -876,6 +963,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -876,6 +963,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
PoolingCompletionRequest
=
EmbeddingCompletionRequest
PoolingChatRequest
=
EmbeddingChatRequest
PoolingRequest
=
Union
[
PoolingCompletionRequest
,
PoolingChatRequest
]
class
ScoreRequest
(
OpenAIBaseModel
):
class
ScoreRequest
(
OpenAIBaseModel
):
model
:
str
model
:
str
...
@@ -971,6 +1062,21 @@ class EmbeddingResponse(OpenAIBaseModel):
...
@@ -971,6 +1062,21 @@ class EmbeddingResponse(OpenAIBaseModel):
usage
:
UsageInfo
usage
:
UsageInfo
class
PoolingResponseData
(
OpenAIBaseModel
):
index
:
int
object
:
str
=
"pooling"
data
:
Union
[
List
[
List
[
float
]],
List
[
float
],
str
]
class
PoolingResponse
(
OpenAIBaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"pool-
{
random_uuid
()
}
"
)
object
:
str
=
"list"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
data
:
List
[
PoolingResponseData
]
usage
:
UsageInfo
class
ScoreResponseData
(
OpenAIBaseModel
):
class
ScoreResponseData
(
OpenAIBaseModel
):
index
:
int
index
:
int
object
:
str
=
"score"
object
:
str
=
"score"
...
...
vllm/entrypoints/openai/run_batch.py
View file @
96ae75ad
...
@@ -232,7 +232,7 @@ async def main(args):
...
@@ -232,7 +232,7 @@ async def main(args):
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
None
,
chat_template
=
None
,
chat_template_content_format
=
"auto"
,
chat_template_content_format
=
"auto"
,
)
if
model_config
.
runner_type
==
"pooling
"
else
None
)
if
model_config
.
task
==
"embed
"
else
None
tracker
=
BatchProgressTracker
()
tracker
=
BatchProgressTracker
()
logger
.
info
(
"Reading batch from %s..."
,
args
.
input_file
)
logger
.
info
(
"Reading batch from %s..."
,
args
.
input_file
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
96ae75ad
...
@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
"been registered"
)
from
e
"been registered"
)
from
e
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
diff_sampling_param
=
self
.
model_config
.
get_diff_sampling_param
()
if
diff_sampling_param
:
logger
.
info
(
"Overwriting default chat sampling param with: %s"
,
diff_sampling_param
)
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
self
,
...
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
default_max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
])
engine_prompt
[
"prompt_token_ids"
])
# Build default sampling params
default_sampling_params
=
(
self
.
model_config
.
get_diff_sampling_param
())
if
request
.
use_beam_search
:
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
default_max_tokens
,
default_sampling_params
)
else
:
else
:
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
default_max_tokens
,
self
.
model_config
.
logits_processor_pattern
)
self
.
model_config
.
logits_processor_pattern
,
default_sampling_params
)
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request_prompts
[
i
],
request_prompts
[
i
],
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
96ae75ad
...
@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapters
=
prompt_adapters
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
diff_sampling_param
=
self
.
model_config
.
get_diff_sampling_param
()
if
diff_sampling_param
:
logger
.
info
(
"Overwriting default completion sampling param with: %s"
,
diff_sampling_param
)
async
def
create_completion
(
async
def
create_completion
(
self
,
self
,
...
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
default_max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
])
engine_prompt
[
"prompt_token_ids"
])
# Build default sampling params
default_sampling_params
=
(
self
.
model_config
.
get_diff_sampling_param
())
if
request
.
use_beam_search
:
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
default_max_tokens
,
default_sampling_params
)
else
:
else
:
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
default_max_tokens
,
self
.
model_config
.
logits_processor_pattern
)
self
.
model_config
.
logits_processor_pattern
,
default_sampling_params
)
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
96ae75ad
...
@@ -40,36 +40,6 @@ def _get_embedding(
...
@@ -40,36 +40,6 @@ def _get_embedding(
assert_never
(
encoding_format
)
assert_never
(
encoding_format
)
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
])
->
EmbeddingResponse
:
data
:
List
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
embedding_res
=
EmbeddingRequestOutput
.
from_base
(
final_res
)
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding
=
_get_embedding
(
embedding_res
.
outputs
,
encoding_format
)
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
embedding
)
data
.
append
(
embedding_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
data
,
usage
=
usage
,
)
class
OpenAIServingEmbedding
(
OpenAIServing
):
class
OpenAIServingEmbedding
(
OpenAIServing
):
def
__init__
(
def
__init__
(
...
@@ -114,7 +84,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -114,7 +84,7 @@ class OpenAIServingEmbedding(OpenAIServing):
model_name
=
request
.
model
model_name
=
request
.
model
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
time
())
truncate_prompt_tokens
=
None
truncate_prompt_tokens
=
None
...
@@ -218,9 +188,13 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -218,9 +188,13 @@ class OpenAIServingEmbedding(OpenAIServing):
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch
)
final_res_batch
)
response
=
request_output_to_embedding_response
(
response
=
self
.
request_output_to_embedding_response
(
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
final_res_batch_checked
,
encoding_format
)
request_id
,
created_time
,
model_name
,
encoding_format
,
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
except
ValueError
as
e
:
...
@@ -228,3 +202,40 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -228,3 +202,40 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
return
response
return
response
def
request_output_to_embedding_response
(
self
,
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
)
->
EmbeddingResponse
:
items
:
List
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
embedding_res
=
EmbeddingRequestOutput
.
from_base
(
final_res
)
item
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
_get_embedding
(
embedding_res
.
outputs
,
encoding_format
),
)
prompt_token_ids
=
final_res
.
prompt_token_ids
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
vllm/entrypoints/openai/serving_pooling.py
0 → 100644
View file @
96ae75ad
import
asyncio
import
base64
import
time
from
typing
import
AsyncGenerator
,
Final
,
List
,
Literal
,
Optional
,
Union
,
cast
import
numpy
as
np
from
fastapi
import
Request
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
PoolingChatRequest
,
PoolingRequest
,
PoolingResponse
,
PoolingResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.utils
import
merge_async_iterators
logger
=
init_logger
(
__name__
)
def
_get_data
(
output
:
PoolingOutput
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
)
->
Union
[
List
[
float
],
str
]:
if
encoding_format
==
"float"
:
return
output
.
data
.
tolist
()
elif
encoding_format
==
"base64"
:
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes
=
np
.
array
(
output
.
data
,
dtype
=
"float32"
).
tobytes
()
return
base64
.
b64encode
(
pooling_bytes
).
decode
(
"utf-8"
)
assert_never
(
encoding_format
)
class
OpenAIServingPooling
(
OpenAIServing
):
def
__init__
(
self
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
base_model_paths
:
List
[
BaseModelPath
],
*
,
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
chat_template_content_format
:
ChatTemplateContentFormatOption
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
async
def
create_pooling
(
self
,
request
:
PoolingRequest
,
raw_request
:
Optional
[
Request
]
=
None
,
)
->
Union
[
PoolingResponse
,
ErrorResponse
]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
request
.
model
request_id
=
f
"pool-
{
self
.
_base_request_id
(
raw_request
)
}
"
created_time
=
int
(
time
.
time
())
truncate_prompt_tokens
=
None
if
request
.
truncate_prompt_tokens
is
not
None
:
if
request
.
truncate_prompt_tokens
<=
self
.
max_model_len
:
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
else
:
return
self
.
create_error_response
(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size."
)
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for pooling models"
)
if
isinstance
(
request
,
PoolingChatRequest
):
(
_
,
request_prompts
,
engine_prompts
,
)
=
await
self
.
_preprocess_chat
(
request
,
tokenizer
,
request
.
messages
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt
=
False
,
continue_final_message
=
False
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
else
:
(
request_prompts
,
engine_prompts
)
=
await
self
.
_preprocess_completion
(
request
,
tokenizer
,
request
.
input
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
except
ValueError
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
merge_async_iterators
(
*
generators
)
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
PoolingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
try
:
async
for
i
,
res
in
result_generator
:
final_res_batch
[
i
]
=
res
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch
)
response
=
self
.
request_output_to_pooling_response
(
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
encoding_format
,
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
response
def
request_output_to_pooling_response
(
self
,
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
)
->
PoolingResponse
:
items
:
List
[
PoolingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
item
=
PoolingResponseData
(
index
=
idx
,
data
=
_get_data
(
final_res
.
outputs
,
encoding_format
),
)
prompt_token_ids
=
final_res
.
prompt_token_ids
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
PoolingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
vllm/entrypoints/openai/serving_score.py
View file @
96ae75ad
...
@@ -20,32 +20,6 @@ from vllm.utils import make_async, merge_async_iterators
...
@@ -20,32 +20,6 @@ from vllm.utils import make_async, merge_async_iterators
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
request_output_to_score_response
(
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
)
->
ScoreResponse
:
data
:
List
[
ScoreResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
classify_res
=
ScoringRequestOutput
.
from_base
(
final_res
)
score_data
=
ScoreResponseData
(
index
=
idx
,
score
=
classify_res
.
outputs
.
score
)
data
.
append
(
score_data
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
ScoreResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
data
,
usage
=
usage
,
)
def
make_pairs
(
text_1
:
Union
[
List
[
str
],
str
],
text_2
:
Union
[
List
[
str
],
def
make_pairs
(
text_1
:
Union
[
List
[
str
],
str
],
text_2
:
Union
[
List
[
str
],
str
])
->
List
:
str
])
->
List
:
if
isinstance
(
text_1
,
(
str
,
dict
)):
if
isinstance
(
text_1
,
(
str
,
dict
)):
...
@@ -103,7 +77,7 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -103,7 +77,7 @@ class OpenAIServingScores(OpenAIServing):
model_name
=
request
.
model
model_name
=
request
.
model
request_id
=
f
"score-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"score-
{
self
.
_base_request_id
(
raw_request
)
}
"
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
time
())
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
request_prompts
=
[]
request_prompts
=
[]
...
@@ -203,8 +177,12 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -203,8 +177,12 @@ class OpenAIServingScores(OpenAIServing):
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch
)
final_res_batch
)
response
=
request_output_to_score_response
(
response
=
self
.
request_output_to_score_response
(
final_res_batch_checked
,
request_id
,
created_time
,
model_name
)
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
except
ValueError
as
e
:
...
@@ -212,3 +190,38 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -212,3 +190,38 @@ class OpenAIServingScores(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
return
response
return
response
def
request_output_to_score_response
(
self
,
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
ScoreResponse
:
items
:
List
[
ScoreResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
classify_res
=
ScoringRequestOutput
.
from_base
(
final_res
)
item
=
ScoreResponseData
(
index
=
idx
,
score
=
classify_res
.
outputs
.
score
,
)
prompt_token_ids
=
final_res
.
prompt_token_ids
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
ScoreResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
View file @
96ae75ad
...
@@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser):
...
@@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
(
tokenizer
)
super
().
__init__
(
tokenizer
)
# for granite 3.0, the token `<|tool_call|>`
self
.
bot_token
=
"<|tool_call|>"
self
.
bot_token
=
"<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
self
.
bot_string
=
"<tool_call>"
def
extract_tool_calls
(
def
extract_tool_calls
(
self
,
model_output
:
str
,
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
ExtractedToolCallInformation
:
request
:
ChatCompletionRequest
)
->
ExtractedToolCallInformation
:
# remove whitespace and the BOT token if it exists
stripped
=
model_output
.
strip
()
\
stripped
=
model_output
.
strip
().
removeprefix
(
self
.
bot_token
).
lstrip
()
.
removeprefix
(
self
.
bot_token
)
\
.
removeprefix
(
self
.
bot_string
)
\
.
lstrip
()
if
not
stripped
or
stripped
[
0
]
!=
'['
:
if
not
stripped
or
stripped
[
0
]
!=
'['
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
tool_calls
=
[],
...
@@ -91,6 +96,9 @@ class GraniteToolParser(ToolParser):
...
@@ -91,6 +96,9 @@ class GraniteToolParser(ToolParser):
if
current_text
[
start_idx
:].
startswith
(
self
.
bot_token
):
if
current_text
[
start_idx
:].
startswith
(
self
.
bot_token
):
start_idx
=
consume_space
(
start_idx
+
len
(
self
.
bot_token
),
start_idx
=
consume_space
(
start_idx
+
len
(
self
.
bot_token
),
current_text
)
current_text
)
if
current_text
[
start_idx
:].
startswith
(
self
.
bot_string
):
start_idx
=
consume_space
(
start_idx
+
len
(
self
.
bot_string
),
current_text
)
if
not
current_text
or
start_idx
>=
len
(
current_text
)
\
if
not
current_text
or
start_idx
>=
len
(
current_text
)
\
or
current_text
[
start_idx
]
!=
'['
:
or
current_text
[
start_idx
]
!=
'['
:
return
DeltaMessage
(
content
=
delta_text
)
return
DeltaMessage
(
content
=
delta_text
)
...
...
vllm/envs.py
View file @
96ae75ad
...
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
...
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
Fals
e
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
Non
e
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
...
@@ -308,7 +308,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -308,7 +308,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will use flashinfer sampler
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER"
:
"VLLM_USE_FLASHINFER_SAMPLER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_SAMPLER"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_USE_FLASHINFER_SAMPLER"
]))
if
"VLLM_USE_FLASHINFER_SAMPLER"
in
os
.
environ
else
None
,
# If set, vllm will force flashinfer to use tensor cores;
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
# otherwise will use heuristic based on model architecture.
...
...
vllm/executor/cpu_executor.py
View file @
96ae75ad
...
@@ -22,7 +22,7 @@ class CPUExecutor(ExecutorBase):
...
@@ -22,7 +22,7 @@ class CPUExecutor(ExecutorBase):
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"cpu"
assert
self
.
device_config
.
device_type
==
"cpu"
# Reminder: Please update docs/source/usage/compatibility_matrix.
rst
# Reminder: Please update docs/source/usage/compatibility_matrix.
md
# If the feature combo become valid
# If the feature combo become valid
assert
self
.
lora_config
is
None
,
"cpu backend doesn't support LoRA"
assert
self
.
lora_config
is
None
,
"cpu backend doesn't support LoRA"
...
...
vllm/executor/ray_gpu_executor.py
View file @
96ae75ad
...
@@ -123,6 +123,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -123,6 +123,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
workers
=
[]
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
continue
...
@@ -138,20 +139,30 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -138,20 +139,30 @@ class RayGPUExecutor(DistributedGPUExecutor):
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
)
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
)
workers
.
append
(
worker
)
if
self
.
use_ray_spmd_worker
:
worker_ip_refs
=
[
self
.
workers
.
append
(
worker
)
worker
.
get_node_ip
.
remote
()
# type: ignore[attr-defined]
else
:
for
worker
in
workers
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
]
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
worker_ips
=
ray
.
get
(
worker_ip_refs
)
if
not
self
.
use_ray_spmd_worker
:
for
i
in
range
(
len
(
workers
)):
worker
=
workers
[
i
]
worker_ip
=
worker_ips
[
i
]
if
self
.
driver_dummy_worker
is
None
and
worker_ip
==
driver_ip
:
# If the worker is on the same node as the driver, we use it
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
self
.
driver_worker
=
RayWorkerWrapper
(
vllm_config
=
self
.
vllm_config
)
vllm_config
=
self
.
vllm_config
)
else
:
workers
.
pop
(
i
)
# Else, added to the list of workers.
worker_ips
.
pop
(
i
)
self
.
workers
.
append
(
worker
)
self
.
workers
=
workers
break
else
:
self
.
workers
=
workers
logger
.
debug
(
"workers: %s"
,
self
.
workers
)
logger
.
debug
(
"workers: %s"
,
self
.
workers
)
logger
.
debug
(
"driver_dummy_worker: %s"
,
self
.
driver_dummy_worker
)
logger
.
debug
(
"driver_dummy_worker: %s"
,
self
.
driver_dummy_worker
)
...
@@ -161,14 +172,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -161,14 +172,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
"adjusting the Ray placement group or running the driver on a "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
"GPU node."
)
worker_ips
=
[
ray
.
get
(
worker
.
get_node_ip
.
remote
())
# type: ignore[attr-defined]
for
worker
in
self
.
workers
]
ip_counts
:
Dict
[
str
,
int
]
=
{}
ip_counts
:
Dict
[
str
,
int
]
=
{}
for
ip
in
worker_ips
:
for
ip
in
worker_ips
:
ip_counts
[
ip
]
=
ip_counts
.
get
(
ip
,
0
)
+
1
ip_counts
[
ip
]
=
ip_counts
.
get
(
ip
,
0
)
+
1
worker_to_ip
=
dict
(
zip
(
self
.
workers
,
worker_ips
))
def
sort_by_driver_then_worker_ip
(
worker
):
def
sort_by_driver_then_worker_ip
(
worker
):
"""
"""
Sort the workers based on 3 properties:
Sort the workers based on 3 properties:
...
@@ -179,7 +188,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -179,7 +188,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
3. Finally, if the work is on a node with smaller IP address, it
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
should be placed first.
"""
"""
ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
ip
=
worker_to_ip
[
worker
]
return
(
ip
!=
driver_ip
,
ip_counts
[
ip
],
ip
)
return
(
ip
!=
driver_ip
,
ip_counts
[
ip
],
ip
)
# After sorting, the workers on the same node will be
# After sorting, the workers on the same node will be
...
...
vllm/inputs/__init__.py
View file @
96ae75ad
...
@@ -13,7 +13,7 @@ The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
...
@@ -13,7 +13,7 @@ The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.
to dispatch data processing according to the target model.
See also:
See also:
:ref:`input
_
processing
_
pipeline`
:ref:`input
-
processing
-
pipeline`
"""
"""
__all__
=
[
__all__
=
[
...
...
vllm/inputs/data.py
View file @
96ae75ad
...
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
...
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
Placeholder ranges for the multi-modal data.
"""
"""
multi_modal_hashes
:
NotRequired
[
List
[
str
]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
"""
Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
...
@@ -177,6 +182,7 @@ def token_inputs(
...
@@ -177,6 +182,7 @@ def token_inputs(
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_hashes
:
Optional
[
List
[
str
]]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
)
->
TokenInputs
:
...
@@ -191,6 +197,8 @@ def token_inputs(
...
@@ -191,6 +197,8 @@ def token_inputs(
inputs
[
"multi_modal_data"
]
=
multi_modal_data
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_inputs
is
not
None
:
if
multi_modal_inputs
is
not
None
:
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
if
multi_modal_hashes
is
not
None
:
inputs
[
"multi_modal_hashes"
]
=
multi_modal_hashes
if
multi_modal_placeholders
is
not
None
:
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
if
mm_processor_kwargs
is
not
None
:
...
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
...
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never
(
inputs
)
assert_never
(
inputs
)
@
cached_property
def
multi_modal_hashes
(
self
)
->
List
[
str
]:
inputs
=
self
.
inputs
if
inputs
[
"type"
]
==
"token"
:
return
inputs
.
get
(
"multi_modal_hashes"
,
[])
if
inputs
[
"type"
]
==
"multimodal"
:
return
inputs
.
get
(
"mm_hashes"
,
[])
assert_never
(
inputs
)
@
cached_property
@
cached_property
def
multi_modal_placeholders
(
self
)
->
"MultiModalPlaceholderDict"
:
def
multi_modal_placeholders
(
self
)
->
"MultiModalPlaceholderDict"
:
inputs
=
self
.
inputs
inputs
=
self
.
inputs
...
...
vllm/inputs/registry.py
View file @
96ae75ad
import
functools
import
functools
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
NamedTuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Mapping
,
NamedTuple
,
Optional
,
Protocol
,
Type
)
Optional
,
Protocol
,
Union
)
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
,
ProcessorMixin
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
typing_extensions
import
TypeVar
,
assert_never
from
typing_extensions
import
TypeVar
,
assert_never
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
...
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
P
=
TypeVar
(
"P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -38,24 +39,28 @@ class InputContext:
...
@@ -38,24 +39,28 @@ class InputContext:
model_config
:
"ModelConfig"
model_config
:
"ModelConfig"
"""The configuration of the model."""
"""The configuration of the model."""
def
get_hf_config
(
self
,
hf_config_type
:
Type
[
C
]
=
PretrainedConfig
)
->
C
:
def
get_hf_config
(
self
,
typ
:
Union
[
type
[
C
],
tuple
[
type
[
C
],
...]]
=
PretrainedConfig
,
/
,
)
->
C
:
"""
"""
Get the HuggingFace configuration
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.
additionally checking its type.
Raises:
Raises:
TypeError: If the
model
is not of the specified type.
TypeError: If the
configuration
is not of the specified type.
"""
"""
hf_config
=
self
.
model_config
.
hf_config
hf_config
=
self
.
model_config
.
hf_config
if
not
isinstance
(
hf_config
,
hf_config_
typ
e
):
if
not
isinstance
(
hf_config
,
typ
):
raise
TypeError
(
"Invalid type of HuggingFace config. "
raise
TypeError
(
"Invalid type of HuggingFace config. "
f
"Expected type:
{
hf_config_
typ
e
}
, but "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
hf_config
)
}
"
)
f
"found type:
{
type
(
hf_config
)
}
"
)
return
hf_config
return
hf_config
def
get_hf_image_processor_config
(
self
)
->
D
ict
[
str
,
Any
]:
def
get_hf_image_processor_config
(
self
)
->
d
ict
[
str
,
Any
]:
"""
"""
Get the HuggingFace image processor configuration of the model.
Get the HuggingFace image processor configuration of the model.
"""
"""
...
@@ -74,18 +79,37 @@ class InputContext:
...
@@ -74,18 +79,37 @@ class InputContext:
return
mm_config
return
mm_config
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
def
get_hf_processor
(
self
,
typ
:
Union
[
type
[
P
],
tuple
[
type
[
P
],
...]]
=
ProcessorMixin
,
/
,
**
kwargs
:
object
,
)
->
P
:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
if
base_kwargs
is
None
:
base_kwargs
=
{}
base_kwargs
=
{}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
return
cached_get_processor
(
hf_processor
=
cached_get_processor
(
self
.
model_config
.
model
,
self
.
model_config
.
model
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
**
merged_kwargs
,
**
merged_kwargs
,
)
)
if
not
isinstance
(
hf_processor
,
typ
):
raise
TypeError
(
"Invalid type of HuggingFace processor. "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
hf_processor
)
}
"
)
return
hf_processor
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
...
@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
tokenizer
:
AnyTokenizer
tokenizer
:
AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
"""The tokenizer used to tokenize the inputs."""
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
def
get_hf_processor
(
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
self
,
if
base_kwargs
is
None
:
typ
:
Union
[
type
[
P
],
tuple
[
type
[
P
],
...]]
=
ProcessorMixin
,
base_kwargs
=
{}
/
,
**
kwargs
:
object
,
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
)
->
P
:
return
super
().
get_hf_processor
(
return
cached_get_processor
(
typ
,
self
.
model_config
.
model
,
tokenizer
=
self
.
tokenizer
,
tokenizer
=
self
.
tokenizer
,
# Override the tokenizer with ours
**
kwargs
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
**
merged_kwargs
,
)
)
def
resolve
_hf_processor
_call_kwargs
(
def
call
_hf_processor
(
self
,
self
,
hf_processor
:
ProcessorMixin
,
hf_processor
:
ProcessorMixin
,
prompt
:
str
,
processor_data
:
Mapping
[
str
,
object
],
inference_kwargs
:
Mapping
[
str
,
object
],
inference_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
object
]
:
)
->
BatchFeature
:
assert
callable
(
hf_processor
)
assert
callable
(
hf_processor
)
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
if
base_kwargs
is
None
:
base_kwargs
=
{}
base_kwargs
=
{}
return
resolve_mm_processor_kwargs
(
merged_kwargs
=
resolve_mm_processor_kwargs
(
base_kwargs
,
base_kwargs
,
inference_kwargs
,
inference_kwargs
,
hf_processor
,
hf_processor
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
)
)
try
:
return
hf_processor
(
text
=
prompt
,
**
processor_data
,
**
merged_kwargs
,
return_tensors
=
"pt"
,
)
except
Exception
as
exc
:
data
=
dict
(
text
=
prompt
,
**
processor_data
)
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
f
"on data=
{
data
}
with kwargs=
{
merged_kwargs
}
"
)
raise
RuntimeError
(
msg
)
from
exc
N
=
TypeVar
(
"N"
,
bound
=
T
ype
[
nn
.
Module
])
N
=
TypeVar
(
"N"
,
bound
=
t
ype
[
nn
.
Module
])
class
DummyData
(
NamedTuple
):
class
DummyData
(
NamedTuple
):
...
@@ -232,7 +272,7 @@ class InputRegistry:
...
@@ -232,7 +272,7 @@ class InputRegistry:
return
wrapper
return
wrapper
def
_get_dummy_data_factory
(
self
,
model_cls
:
T
ype
[
nn
.
Module
]):
def
_get_dummy_data_factory
(
self
,
model_cls
:
t
ype
[
nn
.
Module
]):
return
self
.
_dummy_factories_by_model_type
\
return
self
.
_dummy_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
...
@@ -257,7 +297,7 @@ class InputRegistry:
...
@@ -257,7 +297,7 @@ class InputRegistry:
return
wrapper
return
wrapper
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
T
ype
[
nn
.
Module
]):
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
t
ype
[
nn
.
Module
]):
return
self
.
_dummy_encoder_factories_by_model_type
\
return
self
.
_dummy_encoder_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
...
@@ -274,7 +314,7 @@ class InputRegistry:
...
@@ -274,7 +314,7 @@ class InputRegistry:
The model is identified by ``model_config``.
The model is identified by ``model_config``.
See also:
See also:
:ref:`enabling
_
multimodal
_
inputs`
:ref:`enabling
-
multimodal
-
inputs`
Note:
Note:
This should be called after
This should be called after
...
@@ -351,7 +391,7 @@ class InputRegistry:
...
@@ -351,7 +391,7 @@ class InputRegistry:
happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
See also:
See also:
:ref:`input
_
processing
_
pipeline`
:ref:`input
-
processing
-
pipeline`
"""
"""
def
wrapper
(
model_cls
:
N
)
->
N
:
def
wrapper
(
model_cls
:
N
)
->
N
:
...
@@ -368,14 +408,14 @@ class InputRegistry:
...
@@ -368,14 +408,14 @@ class InputRegistry:
return
wrapper
return
wrapper
def
_get_model_input_processor
(
self
,
model_cls
:
T
ype
[
nn
.
Module
]):
def
_get_model_input_processor
(
self
,
model_cls
:
t
ype
[
nn
.
Module
]):
return
self
.
_input_processors_by_model_type
\
return
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input_processor
)
.
get
(
model_cls
,
self
.
_default_input_processor
)
def
_ensure_mm_kwargs
(
def
_ensure_mm_kwargs
(
self
,
self
,
inputs
:
SingletonInputs
,
inputs
:
SingletonInputs
,
mm_processor_kwargs
:
D
ict
[
str
,
Any
],
mm_processor_kwargs
:
d
ict
[
str
,
Any
],
):
):
if
inputs
[
"type"
]
==
"token"
:
if
inputs
[
"type"
]
==
"token"
:
# In case the input processor for that model fails to set it
# In case the input processor for that model fails to set it
...
@@ -395,7 +435,7 @@ class InputRegistry:
...
@@ -395,7 +435,7 @@ class InputRegistry:
The model is identified by ``model_config``.
The model is identified by ``model_config``.
See also:
See also:
:ref:`input
_
processing
_
pipeline`
:ref:`input
-
processing
-
pipeline`
"""
"""
# Avoid circular import
# Avoid circular import
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.model_executor.model_loader
import
get_model_architecture
...
...
vllm/lora/layers.py
View file @
96ae75ad
...
@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
if
self
.
base_layer
.
skip_bias_add
else
None
)
if
self
.
base_layer
.
skip_bias_add
else
None
)
return
output
,
output_bias
return
output
,
output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
def
can_replace_layer
(
cls
,
cls
,
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment