Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
672 additions
and
241 deletions
+672
-241
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+25
-9
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+38
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-7
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+77
-10
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+35
-19
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+78
-1
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+18
-7
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+0
-1
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+5
-15
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+8
-19
vllm/envs.py
vllm/envs.py
+12
-0
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+20
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+6
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+21
-2
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/lora/request.py
vllm/lora/request.py
+18
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+10
-4
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+219
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+20
-118
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+51
-24
No files found.
vllm/entrypoints/llm.py
View file @
4851c202
...
@@ -6,7 +6,8 @@ from tqdm import tqdm
...
@@ -6,7 +6,8 @@ from tqdm import tqdm
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
parse_chat_messages
)
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
...
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
...
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -55,7 +56,7 @@ class LLM:
...
@@ -55,7 +56,7 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq",
"squeezellm",
and "fp8" (experimental).
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
not quantized and use `dtype` to determine the data type of
...
@@ -393,12 +394,21 @@ class LLM:
...
@@ -393,12 +394,21 @@ class LLM:
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
tokenizer
)
prompt
=
apply_chat_template
(
prompt
:
Union
[
str
,
List
[
int
]]
tokenizer
,
if
isinstance
(
tokenizer
,
MistralTokenizer
):
conversation
,
prompt
=
apply_mistral_chat_template
(
chat_template
=
chat_template
,
tokenizer
,
add_generation_prompt
=
add_generation_prompt
,
messages
=
messages
,
)
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
inputs
:
PromptInputs
inputs
:
PromptInputs
if
is_list_of
(
prompt
,
int
):
if
is_list_of
(
prompt
,
int
):
...
@@ -560,6 +570,12 @@ class LLM:
...
@@ -560,6 +570,12 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
return
LLMEngine
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
def
stop_profile
(
self
)
->
None
:
self
.
llm_engine
.
stop_profile
()
# LEGACY
# LEGACY
def
_convert_v1_inputs
(
def
_convert_v1_inputs
(
self
,
self
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
4851c202
...
@@ -35,11 +35,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -35,11 +35,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeResponse
,
DetokenizeResponse
,
EmbeddingRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
EmbeddingResponse
,
ErrorResponse
,
LoadLoraAdapterRequest
,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
)
TokenizeResponse
,
# yapf: enable
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
...
@@ -343,6 +345,40 @@ if envs.VLLM_TORCH_PROFILER_DIR:
...
@@ -343,6 +345,40 @@ if envs.VLLM_TORCH_PROFILER_DIR:
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
logger
.
warning
(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
return
Response
(
status_code
=
200
,
content
=
response
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
include_router
(
router
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
4851c202
...
@@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel):
...
@@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel):
function
:
Optional
[
DeltaFunctionCall
]
=
None
function
:
Optional
[
DeltaFunctionCall
]
=
None
# the initial delta that gets sent once a new tool call is started;
class
InitialDeltaToolCall
(
DeltaToolCall
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-tool-
{
random_uuid
()
}
"
)
type
:
Literal
[
"function"
]
=
"function"
index
:
int
class
ExtractedToolCallInformation
(
BaseModel
):
class
ExtractedToolCallInformation
(
BaseModel
):
# indicate if tools were called
# indicate if tools were called
tools_called
:
bool
tools_called
:
bool
...
@@ -878,3 +871,13 @@ class DetokenizeRequest(OpenAIBaseModel):
...
@@ -878,3 +871,13 @@ class DetokenizeRequest(OpenAIBaseModel):
class
DetokenizeResponse
(
OpenAIBaseModel
):
class
DetokenizeResponse
(
OpenAIBaseModel
):
prompt
:
str
prompt
:
str
class
LoadLoraAdapterRequest
(
BaseModel
):
lora_name
:
str
lora_path
:
str
class
UnloadLoraAdapterRequest
(
BaseModel
):
lora_name
:
str
lora_int_id
:
Optional
[
int
]
=
Field
(
default
=
None
)
vllm/entrypoints/openai/run_batch.py
View file @
4851c202
import
asyncio
import
asyncio
from
http
import
HTTPStatus
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Awaitable
,
Callable
,
List
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
import
aiohttp
import
aiohttp
import
torch
from
prometheus_client
import
start_http_server
from
prometheus_client
import
start_http_server
from
tqdm
import
tqdm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
...
@@ -78,6 +81,38 @@ def parse_args():
...
@@ -78,6 +81,38 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT
=
"{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]
\n
"
# noqa: E501
class
BatchProgressTracker
:
def
__init__
(
self
):
self
.
_total
=
0
self
.
_pbar
:
Optional
[
tqdm
]
=
None
def
submitted
(
self
):
self
.
_total
+=
1
def
completed
(
self
):
if
self
.
_pbar
:
self
.
_pbar
.
update
()
def
pbar
(
self
)
->
tqdm
:
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
)
or
torch
.
distributed
.
get_rank
()
==
0
self
.
_pbar
=
tqdm
(
total
=
self
.
_total
,
unit
=
"req"
,
desc
=
"Running batch"
,
mininterval
=
5
,
disable
=
not
enable_tqdm
,
bar_format
=
_BAR_FORMAT
)
return
self
.
_pbar
async
def
read_file
(
path_or_url
:
str
)
->
str
:
async
def
read_file
(
path_or_url
:
str
)
->
str
:
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
async
with
aiohttp
.
ClientSession
()
as
session
,
\
async
with
aiohttp
.
ClientSession
()
as
session
,
\
...
@@ -101,8 +136,28 @@ async def write_file(path_or_url: str, data: str) -> None:
...
@@ -101,8 +136,28 @@ async def write_file(path_or_url: str, data: str) -> None:
f
.
write
(
data
)
f
.
write
(
data
)
def
make_error_request_output
(
request
:
BatchRequestInput
,
error_msg
:
str
)
->
BatchRequestOutput
:
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
response
=
BatchResponseData
(
status_code
=
HTTPStatus
.
BAD_REQUEST
,
request_id
=
f
"vllm-batch-
{
random_uuid
()
}
"
,
),
error
=
error_msg
,
)
return
batch_output
async
def
make_async_error_request_output
(
request
:
BatchRequestInput
,
error_msg
:
str
)
->
BatchRequestOutput
:
return
make_error_request_output
(
request
,
error_msg
)
async
def
run_request
(
serving_engine_func
:
Callable
,
async
def
run_request
(
serving_engine_func
:
Callable
,
request
:
BatchRequestInput
)
->
BatchRequestOutput
:
request
:
BatchRequestInput
,
tracker
:
BatchProgressTracker
)
->
BatchRequestOutput
:
response
=
await
serving_engine_func
(
request
.
body
)
response
=
await
serving_engine_func
(
request
.
body
)
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
)):
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
)):
...
@@ -123,8 +178,10 @@ async def run_request(serving_engine_func: Callable,
...
@@ -123,8 +178,10 @@ async def run_request(serving_engine_func: Callable,
error
=
response
,
error
=
response
,
)
)
else
:
else
:
raise
ValueError
(
"Request must not be sent in stream mode"
)
batch_output
=
make_error_request_output
(
request
,
error_msg
=
"Request must not be sent in stream mode"
)
tracker
.
completed
()
return
batch_output
return
batch_output
...
@@ -164,6 +221,9 @@ async def main(args):
...
@@ -164,6 +221,9 @@ async def main(args):
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
tracker
=
BatchProgressTracker
()
logger
.
info
(
"Reading batch from %s..."
,
args
.
input_file
)
# Submit all requests in the file to the engine "concurrently".
# Submit all requests in the file to the engine "concurrently".
response_futures
:
List
[
Awaitable
[
BatchRequestOutput
]]
=
[]
response_futures
:
List
[
Awaitable
[
BatchRequestOutput
]]
=
[]
for
request_json
in
(
await
read_file
(
args
.
input_file
)).
strip
().
split
(
"
\n
"
):
for
request_json
in
(
await
read_file
(
args
.
input_file
)).
strip
().
split
(
"
\n
"
):
...
@@ -178,16 +238,23 @@ async def main(args):
...
@@ -178,16 +238,23 @@ async def main(args):
if
request
.
url
==
"/v1/chat/completions"
:
if
request
.
url
==
"/v1/chat/completions"
:
response_futures
.
append
(
response_futures
.
append
(
run_request
(
openai_serving_chat
.
create_chat_completion
,
run_request
(
openai_serving_chat
.
create_chat_completion
,
request
))
request
,
tracker
))
tracker
.
submitted
()
elif
request
.
url
==
"/v1/embeddings"
:
elif
request
.
url
==
"/v1/embeddings"
:
response_futures
.
append
(
response_futures
.
append
(
run_request
(
openai_serving_embedding
.
create_embedding
,
run_request
(
openai_serving_embedding
.
create_embedding
,
request
,
request
))
tracker
))
tracker
.
submitted
()
else
:
else
:
raise
ValueError
(
"Only /v1/chat/completions and /v1/embeddings are"
response_futures
.
append
(
"supported in the batch endpoint."
)
make_async_error_request_output
(
request
,
responses
=
await
asyncio
.
gather
(
*
response_futures
)
error_msg
=
"Only /v1/chat/completions and "
"/v1/embeddings are supported in the batch endpoint."
,
))
with
tracker
.
pbar
():
responses
=
await
asyncio
.
gather
(
*
response_futures
)
output_buffer
=
StringIO
()
output_buffer
=
StringIO
()
for
response
in
responses
:
for
response
in
responses
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
4851c202
...
@@ -11,7 +11,8 @@ from fastapi import Request
...
@@ -11,7 +11,8 @@ from fastapi import Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
...
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
tool
.
model_dump
()
for
tool
in
request
.
tools
tool
.
model_dump
()
for
tool
in
request
.
tools
]
]
prompt
=
apply_chat_template
(
prompt
:
Union
[
str
,
List
[
int
]]
tokenizer
,
if
isinstance
(
tokenizer
,
MistralTokenizer
):
conversation
=
conversation
,
prompt
=
apply_mistral_chat_template
(
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
tokenizer
,
add_generation_prompt
=
request
.
add_generation_prompt
,
messages
=
request
.
messages
,
tools
=
tool_dicts
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
documents
=
request
.
documents
,
add_generation_prompt
=
request
.
add_generation_prompt
,
**
(
request
.
chat_template_kwargs
or
{}),
tools
=
tool_dicts
,
)
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -271,9 +284,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -271,9 +284,13 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# NOTE num_choices defaults to 1 so this usually executes
# once per request
# once per request
for
i
in
range
(
num_choices
):
for
i
in
range
(
num_choices
):
choice_data
=
ChatCompletionResponseStreamChoice
(
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
delta
=
DeltaMessage
(
role
=
role
,
content
=
""
,
),
logprobs
=
None
,
logprobs
=
None
,
finish_reason
=
None
)
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
chunk
=
ChatCompletionStreamResponse
(
...
@@ -303,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -303,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# Send response to echo the input portion of the
# last message
# last message
if
request
.
echo
:
if
request
.
echo
:
last_msg_content
:
Optional
[
str
]
=
""
last_msg_content
:
str
=
""
if
conversation
and
conversation
[
-
1
].
get
(
if
conversation
and
"content"
in
conversation
[
"content"
)
and
conversation
[
-
1
].
get
(
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
last_msg_content
=
conversation
[
-
1
][
"content"
]
if
last_msg_content
:
if
last_msg_content
:
for
i
in
range
(
num_choices
):
for
i
in
range
(
num_choices
):
...
@@ -655,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -655,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
echo
:
if
request
.
echo
:
last_msg_content
=
""
last_msg_content
=
""
if
conversation
and
conversation
[
-
1
]
.
get
(
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
for
choice
in
choices
:
for
choice
in
choices
:
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
4851c202
...
@@ -16,11 +16,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -16,11 +16,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
ErrorResponse
,
EmbeddingRequest
,
ErrorResponse
,
LoadLoraAdapterRequest
,
ModelCard
,
ModelList
,
ModelCard
,
ModelList
,
ModelPermission
,
ModelPermission
,
TokenizeChatRequest
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
TokenizeCompletionRequest
,
TokenizeRequest
)
TokenizeRequest
,
UnloadLoraAdapterRequest
)
# yapf: enable
# yapf: enable
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -32,6 +34,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -32,6 +34,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
AtomicCounter
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -78,6 +81,7 @@ class OpenAIServing:
...
@@ -78,6 +81,7 @@ class OpenAIServing:
self
.
served_model_names
=
served_model_names
self
.
served_model_names
=
served_model_names
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_requests
=
[]
self
.
lora_requests
=
[]
if
lora_modules
is
not
None
:
if
lora_modules
is
not
None
:
self
.
lora_requests
=
[
self
.
lora_requests
=
[
...
@@ -403,3 +407,76 @@ class OpenAIServing:
...
@@ -403,3 +407,76 @@ class OpenAIServing:
if
logprob
.
decoded_token
is
not
None
:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
logprob
.
decoded_token
return
tokenizer
.
decode
(
token_id
)
return
tokenizer
.
decode
(
token_id
)
async
def
_check_load_lora_adapter_request
(
self
,
request
:
LoadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if both 'lora_name' and 'lora_path' are provided
if
not
request
.
lora_name
or
not
request
.
lora_path
:
return
self
.
create_error_response
(
message
=
"Both 'lora_name' and 'lora_path' must be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name already exists
if
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' has already been"
"loaded."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async
def
_check_unload_lora_adapter_request
(
self
,
request
:
UnloadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if
not
request
.
lora_name
and
not
request
.
lora_int_id
:
return
self
.
create_error_response
(
message
=
"either 'lora_name' and 'lora_int_id' needs to be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name exists
if
not
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' cannot be found."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async
def
load_lora_adapter
(
self
,
request
:
LoadLoraAdapterRequest
)
->
Union
[
ErrorResponse
,
str
]:
error_check_ret
=
await
self
.
_check_load_lora_adapter_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
lora_name
,
lora_path
=
request
.
lora_name
,
request
.
lora_path
unique_id
=
self
.
lora_id_counter
.
inc
(
1
)
self
.
lora_requests
.
append
(
LoRARequest
(
lora_name
=
lora_name
,
lora_int_id
=
unique_id
,
lora_path
=
lora_path
))
return
f
"Success: LoRA adapter '
{
lora_name
}
' added successfully."
async
def
unload_lora_adapter
(
self
,
request
:
UnloadLoraAdapterRequest
)
->
Union
[
ErrorResponse
,
str
]:
error_check_ret
=
await
self
.
_check_unload_lora_adapter_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
lora_name
=
request
.
lora_name
self
.
lora_requests
=
[
lora_request
for
lora_request
in
self
.
lora_requests
if
lora_request
.
lora_name
!=
lora_name
]
return
f
"Success: LoRA adapter '
{
lora_name
}
' removed successfully."
vllm/entrypoints/openai/serving_tokenization.py
View file @
4851c202
...
@@ -2,7 +2,8 @@ from typing import List, Optional, Union
...
@@ -2,7 +2,8 @@ from typing import List, Optional, Union
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
...
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
...
@@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
logger
.
warning
(
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
"Multi-modal inputs are ignored during tokenization"
)
prompt
=
apply_chat_template
(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tokenizer
,
prompt
=
apply_mistral_chat_template
(
conversation
=
conversation
,
tokenizer
,
chat_template
=
self
.
chat_template
,
messages
=
request
.
messages
,
add_generation_prompt
=
request
.
add_generation_prompt
,
chat_template
=
self
.
chat_template
,
)
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
else
:
prompt
=
request
.
prompt
prompt
=
request
.
prompt
...
...
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
View file @
4851c202
...
@@ -20,7 +20,6 @@ class ToolParser:
...
@@ -20,7 +20,6 @@ class ToolParser:
# the index of the tool call that is currently being parsed
# the index of the tool call that is currently being parsed
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[]
self
.
streamed_args_for_tool
:
List
[
str
]
=
[]
self
.
model_tokenizer
=
tokenizer
self
.
model_tokenizer
=
tokenizer
...
...
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
4851c202
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
DeltaToolCall
,
ExtractedToolCallInformation
,
ExtractedToolCallInformation
,
FunctionCall
,
FunctionCall
,
ToolCall
)
InitialDeltaToolCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -34,7 +34,6 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -34,7 +34,6 @@ class Hermes2ProToolParser(ToolParser):
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
...
@@ -168,7 +167,6 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -168,7 +167,6 @@ class Hermes2ProToolParser(ToolParser):
# set cursors and state appropriately
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
...
@@ -218,24 +216,16 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -218,24 +216,16 @@ class Hermes2ProToolParser(ToolParser):
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
return
None
return
None
# case - we haven't sent the initial delta with the tool call ID
# (it will be sent)
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# case - we haven't sent the tool name yet. If it's available, send
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
# it. otherwise, wait until it's available.
el
if
not
self
.
current_tool_name_sent
:
if
not
self
.
current_tool_name_sent
:
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
if
function_name
:
self
.
current_tool_name_sent
=
True
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
f
"chatcmpl-tool-
{
random_uuid
()
}
"
,
function
=
DeltaFunctionCall
(
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
exclude_none
=
True
))
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
4851c202
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
DeltaToolCall
,
ExtractedToolCallInformation
,
ExtractedToolCallInformation
,
FunctionCall
,
FunctionCall
,
ToolCall
)
InitialDeltaToolCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -25,7 +25,7 @@ class MistralToolParser(ToolParser):
...
@@ -25,7 +25,7 @@ class MistralToolParser(ToolParser):
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser
g
mistral are all set
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
"""
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
...
@@ -42,7 +42,6 @@ class MistralToolParser(ToolParser):
...
@@ -42,7 +42,6 @@ class MistralToolParser(ToolParser):
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token
=
"[TOOL_CALLS]"
...
@@ -91,7 +90,6 @@ class MistralToolParser(ToolParser):
...
@@ -91,7 +90,6 @@ class MistralToolParser(ToolParser):
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in extracting tool call from response: %s"
,
e
)
logger
.
error
(
"Error in extracting tool call from response: %s"
,
e
)
print
(
"ERROR"
,
e
)
# return information to just treat the tool call as regular JSON
# return information to just treat the tool call as regular JSON
return
ExtractedToolCallInformation
(
tools_called
=
False
,
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
tool_calls
=
[],
...
@@ -109,7 +107,7 @@ class MistralToolParser(ToolParser):
...
@@ -109,7 +107,7 @@ class MistralToolParser(ToolParser):
# if the tool call token is not in the tokens generated so far, append
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
# output to contents since it's not a tool
if
self
.
bot_token
_id
not
in
current_t
oken_ids
:
if
self
.
bot_token
not
in
current_t
ext
:
return
DeltaMessage
(
content
=
delta_text
)
return
DeltaMessage
(
content
=
delta_text
)
# if the tool call token ID IS in the tokens generated so far, that
# if the tool call token ID IS in the tokens generated so far, that
...
@@ -134,7 +132,7 @@ class MistralToolParser(ToolParser):
...
@@ -134,7 +132,7 @@ class MistralToolParser(ToolParser):
# replace BOT token with empty string, and convert single quotes
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
# quotes instead of double for tool calls
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
1
]
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
-
1
]
# tool calls are generated in an array, so do partial JSON
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
# parsing on the entire array
...
@@ -186,31 +184,22 @@ class MistralToolParser(ToolParser):
...
@@ -186,31 +184,22 @@ class MistralToolParser(ToolParser):
# re-set stuff pertaining to progress in the current tool
# re-set stuff pertaining to progress in the current tool
self
.
current_tool_id
=
len
(
tool_call_arr
)
-
1
self
.
current_tool_id
=
len
(
tool_call_arr
)
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"starting on new tool %d"
,
self
.
current_tool_id
)
logger
.
debug
(
"starting on new tool %d"
,
self
.
current_tool_id
)
return
delta
return
delta
# case: update an existing tool - this is handled below
# case: update an existing tool - this is handled below
# if the current tool initial data incl. the id, type=function
# and idx not sent, send that
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
delta
=
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# if the current tool name hasn't been sent, send if available
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
# - otherwise send nothing
el
if
not
self
.
current_tool_name_sent
:
if
not
self
.
current_tool_name_sent
:
function_name
=
current_tool_call
.
get
(
"name"
)
function_name
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
f
"chatcmpl-tool-
{
random_uuid
()
}
"
,
function
=
DeltaFunctionCall
(
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
exclude_none
=
True
))
...
...
vllm/envs.py
View file @
4851c202
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -226,6 +227,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -226,6 +227,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
"LOCAL_RANK"
:
"LOCAL_RANK"
:
...
@@ -433,6 +439,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -433,6 +439,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will use Triton implementations of AWQ.
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ"
:
"VLLM_USE_TRITON_AWQ"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
# If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/executor/cpu_executor.py
View file @
4851c202
...
@@ -5,7 +5,8 @@ from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
...
@@ -5,7 +5,8 @@ from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
ResultHandler
,
WorkerMonitor
)
...
@@ -60,6 +61,8 @@ class CPUExecutor(ExecutorBase):
...
@@ -60,6 +61,8 @@ class CPUExecutor(ExecutorBase):
self
.
cache_config
=
_verify_and_get_cache_config
(
self
.
cache_config
)
self
.
cache_config
=
_verify_and_get_cache_config
(
self
.
cache_config
)
self
.
scheduler_config
=
_verify_and_get_scheduler_config
(
self
.
scheduler_config
=
_verify_and_get_scheduler_config
(
self
.
scheduler_config
)
self
.
scheduler_config
)
self
.
parallel_config
=
_verify_and_get_parallel_config
(
self
.
parallel_config
)
# Multiprocessing-based executor does not support multi-node setting.
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# Since it only works for single node, we can use the loopback address
...
@@ -296,6 +299,12 @@ class CPUExecutor(ExecutorBase):
...
@@ -296,6 +299,12 @@ class CPUExecutor(ExecutorBase):
for
result
in
parallel_worker_tasks
:
for
result
in
parallel_worker_tasks
:
result
.
get
()
result
.
get
()
def
start_profile
(
self
)
->
None
:
self
.
driver_method_invoker
(
self
.
driver_worker
,
"start_profile"
)
def
stop_profile
(
self
)
->
None
:
self
.
driver_method_invoker
(
self
.
driver_worker
,
"stop_profile"
)
class
CPUExecutorAsync
(
CPUExecutor
,
ExecutorAsyncBase
):
class
CPUExecutorAsync
(
CPUExecutor
,
ExecutorAsyncBase
):
...
@@ -353,6 +362,16 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
...
@@ -353,6 +362,16 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
return
config
return
config
def
_verify_and_get_parallel_config
(
config
:
ParallelConfig
)
->
ParallelConfig
:
if
(
config
.
distributed_executor_backend
is
not
None
and
config
.
distributed_executor_backend
!=
"mp"
):
logger
.
warning
(
"%s is not supported on CPU, fallback to mp distributed executor "
"backend."
,
config
.
distributed_executor_backend
)
config
.
distributed_executor_backend
=
"mp"
return
config
def
_driver_method_invoker
(
driver
,
method
:
str
,
*
args
,
**
kwargs
):
def
_driver_method_invoker
(
driver
,
method
:
str
,
*
args
,
**
kwargs
):
return
getattr
(
driver
,
method
)(
*
args
,
**
kwargs
)
return
getattr
(
driver
,
method
)(
*
args
,
**
kwargs
)
...
...
vllm/executor/gpu_executor.py
View file @
4851c202
...
@@ -169,6 +169,12 @@ class GPUExecutor(ExecutorBase):
...
@@ -169,6 +169,12 @@ class GPUExecutor(ExecutorBase):
# it's running.
# it's running.
return
return
def
start_profile
(
self
)
->
None
:
self
.
driver_worker
.
start_profile
()
def
stop_profile
(
self
)
->
None
:
self
.
driver_worker
.
stop_profile
()
class
GPUExecutorAsync
(
GPUExecutor
,
ExecutorAsyncBase
):
class
GPUExecutorAsync
(
GPUExecutor
,
ExecutorAsyncBase
):
...
...
vllm/executor/ray_gpu_executor.py
View file @
4851c202
...
@@ -242,6 +242,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -242,6 +242,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
VLLM_INSTANCE_ID
,
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
str
(
envs
.
VLLM_TRACE_FUNCTION
),
**
({
"VLLM_ATTENTION_BACKEND"
:
envs
.
VLLM_ATTENTION_BACKEND
}
if
envs
.
VLLM_ATTENTION_BACKEND
is
not
None
else
{})
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
self
.
_env_vars_for_all_workers
=
(
self
.
_env_vars_for_all_workers
=
(
...
@@ -427,18 +430,34 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -427,18 +430,34 @@ class RayGPUExecutor(DistributedGPUExecutor):
async_run_remote_workers_only to complete."""
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
ray
.
get
(
parallel_worker_tasks
)
def
_c
ompiled
_ray_dag
(
self
,
enable_asyncio
:
bool
):
def
_c
heck
_ray_
a
dag
_installation
(
self
):
import
pkg_resources
import
pkg_resources
from
packaging
import
version
from
packaging
import
version
required_version
=
version
.
parse
(
"2.3
2
"
)
required_version
=
version
.
parse
(
"2.3
5
"
)
current_version
=
version
.
parse
(
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
import
importlib.util
adag_spec
=
importlib
.
util
.
find_spec
(
"ray.experimental.compiled_dag_ref"
)
if
adag_spec
is
None
:
raise
ValueError
(
"Ray accelerated DAG is not installed. "
"Run `pip install ray[adag]` to install it."
)
cupy_spec
=
importlib
.
util
.
find_spec
(
"cupy"
)
if
cupy_spec
is
None
and
envs
.
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
:
raise
ValueError
(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
"Run `pip install ray[adag]` and check cupy installation."
)
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
assert
self
.
parallel_config
.
use_ray
assert
self
.
parallel_config
.
use_ray
self
.
_check_ray_adag_installation
()
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.experimental.channel.torch_tensor_type
import
TorchTensorType
from
ray.experimental.channel.torch_tensor_type
import
TorchTensorType
...
...
vllm/lora/layers.py
View file @
4851c202
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
# unquantizedLinear
if
hasattr
(
base_layer
,
"weight"
):
if
hasattr
(
base_layer
,
"weight"
):
return
base_layer
.
weight
.
device
return
base_layer
.
weight
.
device
# GPTQ/AWQ
/SqueezeLLM
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
return
base_layer
.
qweight
.
device
# marlin
# marlin
...
...
vllm/lora/request.py
View file @
4851c202
...
@@ -28,7 +28,6 @@ class LoRARequest(
...
@@ -28,7 +28,6 @@ class LoRARequest(
lora_path
:
str
=
""
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__struct_fields__
:
if
'lora_local_path'
in
self
.
__struct_fields__
:
...
@@ -75,3 +74,21 @@ class LoRARequest(
...
@@ -75,3 +74,21 @@ class LoRARequest(
DeprecationWarning
,
DeprecationWarning
,
stacklevel
=
2
)
stacklevel
=
2
)
self
.
lora_path
=
value
self
.
lora_path
=
value
def
__eq__
(
self
,
value
:
object
)
->
bool
:
"""
Overrides the equality method to compare LoRARequest
instances based on lora_name. This allows for identification
and comparison lora adapter across engines.
"""
return
isinstance
(
value
,
self
.
__class__
)
and
self
.
lora_name
==
value
.
lora_name
def
__hash__
(
self
)
->
int
:
"""
Overrides the hash method to hash LoRARequest instances
based on lora_name. This ensures that LoRARequest instances
can be used in hash-based collections such as sets and dictionaries,
identified by their names across engines.
"""
return
hash
(
self
.
lora_name
)
vllm/model_executor/layers/fused_moe/__init__.py
View file @
4851c202
...
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
...
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
]
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
,
]
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_
marlin_
moe
,
fused_
moe
,
fused_topk
,
fused_experts
,
fused_moe
,
fused_
topk
,
get_config_file_name
,
get_config_file_name
,
grouped_topk
)
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_marlin_moe"
,
"fused_marlin_moe"
,
"single_marlin_moe"
,
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
0 → 100644
View file @
4851c202
"""Fused MoE utilities for GPTQ."""
import
functools
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
torch
.
Tensor
:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (torch.Tensor): The act_order indices.
- perm (torch.Tensor): The act_order input permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w
.
shape
[
1
]
*
16
,
"Hidden size mismatch"
assert
gating_output
.
shape
[
1
]
==
w
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
M
,
K
=
hidden_states
.
shape
E
=
w
.
shape
[
0
]
N
=
w
.
shape
[
2
]
//
2
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
# This might not be an optimal config for a single MMM
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w
.
shape
,
w
.
shape
,
topk_ids
.
shape
[
1
],
None
,
override_config
=
override_config
,
is_marlin
=
True
)
config
=
get_config_func
(
M
)
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
max_workspace_size
=
(
N
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
requires_grad
=
False
)
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
g_idx
,
perm
,
workspace
,
M
,
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx1
:
torch
.
Tensor
,
g_idx2
:
torch
.
Tensor
,
perm1
:
torch
.
Tensor
,
perm2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation.
- perm2 (torch.Tensor): The second act_order input permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
"Number of tokens mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
2
,
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
topk
=
topk_ids
.
shape
[
1
]
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
None
,
override_config
=
override_config
,
is_marlin
=
True
,
)
config
=
get_config_func
(
M
)
block_size_m
=
config
[
"BLOCK_SIZE_M"
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
max_workspace_size
=
((
M
+
255
)
//
256
)
*
(
max
(
2
*
N
,
K
)
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
requires_grad
=
False
)
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w1_scale
,
g_idx1
,
perm1
,
workspace
,
M
,
2
*
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
,
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
intermediate_cache2
,
w2
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w2_scale
,
g_idx2
,
perm2
,
workspace
,
M
,
K
,
N
,
True
,
E
,
topk
,
block_size_m
,
False
,
True
,
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
4851c202
...
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
...
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
return
None
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
def
get_default_config
(
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
)
->
Dict
[
str
,
int
]:
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
)
->
Dict
[
str
,
int
]:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
'GROUP_SIZE_M'
:
8
}
}
# A heuristic: fused marlin works faster with this config for small M
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_M'
:
16
,
...
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
...
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
def
try_get_optimal_moe_config
(
w2_shape
:
Tuple
[
int
,
...],
w1_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
w2_shape
:
Tuple
[
int
,
...],
dtype
:
Optional
[
str
],
top_k
:
int
,
M
:
int
,
dtype
:
Optional
[
str
],
override_config
:
Optional
[
Dict
[
str
,
M
:
int
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
):
is_marlin
:
bool
=
False
,
):
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
else
:
else
:
...
@@ -391,6 +399,7 @@ def fused_topk(
...
@@ -391,6 +399,7 @@ def fused_topk(
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
ops
.
topk_softmax
(
ops
.
topk_softmax
(
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
...
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx1
:
torch
.
Tensor
,
g_idx2
:
torch
.
Tensor
,
rand_perm1
:
torch
.
Tensor
,
rand_perm2
:
torch
.
Tensor
,
topk
:
int
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
renormalize
:
bool
=
True
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
2
,
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
#TODO fp8 is not implemented yet
assert
not
use_fp8
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
if
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
else
:
topk_weights
,
topk_ids
=
custom_routing_function
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
,
override_config
=
override_config
,
is_marlin
=
True
)
config
=
get_config_func
(
M
)
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
max_workspace_size
=
((
M
+
255
)
//
256
)
*
(
max
(
2
*
N
,
K
)
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
requires_grad
=
False
)
intermediate_cache2
=
torch
.
empty
((
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w1_scale
,
g_idx1
,
rand_perm1
,
workspace
,
M
,
2
*
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
intermediate_cache2
,
w2
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w2_scale
,
g_idx2
,
rand_perm2
,
workspace
,
M
,
K
,
N
,
True
,
E
,
topk
,
block_size_m
,
False
,
True
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
4851c202
...
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
...
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
# Input scales can be loaded directly and should be equal.
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
def
_load_g_idx
(
self
,
shard_id
:
str
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
if
shard_id
==
"w2"
:
self
.
_load_w2
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
else
:
assert
shard_id
in
(
"w1"
,
"w3"
)
expert_data
.
copy_
(
loaded_weight
)
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
f
"got
{
shard_id
}
."
)
...
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
...
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed:
whether or not the parameter is transposed on disk
# is_transposed:
if the dim to shard the weight
#
If transposed, the loaded weight will be transposed and the dim
#
should be flipped. Required by GPTQ, compressed-tensors
#
to shard the loaded weight will be flipped.
#
should be whatever dimension intermediate_size is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
if
is_transposed
:
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
shard_dim
=
~
shard_dim
shard_dim
=
~
shard_dim
# Case weight_scales
# Case input scale: input_scale loading is only supported for fp8
if
"weight_scale"
in
weight_name
:
if
"input_scale"
in
weight_name
:
# load the weight scaling based on the quantization scheme
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
# supported weight scales can be found in
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case g_idx
if
"g_idx"
in
weight_name
:
self
.
_load_g_idx
(
shard_dim
=
0
,
shard_id
=
shard_id
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
return
# Case weight scales and zero_points
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
# specific to each case
...
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
return
# Case weight_shape
if
"weight_shape"
in
weight_name
:
if
"weight_shape"
in
weight_name
:
self
.
_load_single_value
(
param
=
param
,
# only required by compressed-tensors
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case input scale
if
"input_scale"
in
weight_name
:
# Note: input_scale loading is only supported for fp8
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
self
.
_load_single_value
(
param
=
param
,
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
...
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
...
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
param_data
[
expert_id
][
idx
]
=
loaded_weight
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# If we are in the row parallel case (down_proj)
else
:
else
:
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
\ No newline at end of file
Prev
1
…
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment