Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
672 additions
and
241 deletions
+672
-241
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+25
-9
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+38
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-7
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+77
-10
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+35
-19
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+78
-1
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+18
-7
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+0
-1
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+5
-15
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+8
-19
vllm/envs.py
vllm/envs.py
+12
-0
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+20
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+6
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+21
-2
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/lora/request.py
vllm/lora/request.py
+18
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+10
-4
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+219
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+20
-118
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+51
-24
No files found.
vllm/entrypoints/llm.py
View file @
4851c202
...
...
@@ -6,7 +6,8 @@ from tqdm import tqdm
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
...
...
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
...
...
@@ -55,7 +56,7 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq",
"squeezellm",
and "fp8" (experimental).
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
...
...
@@ -393,12 +394,21 @@ class LLM:
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
prompt
=
apply_chat_template
(
tokenizer
,
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
messages
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
inputs
:
PromptInputs
if
is_list_of
(
prompt
,
int
):
...
...
@@ -560,6 +570,12 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
def start_profile(self) -> None:
    """Forward a start-profile request to ``self.llm_engine``."""
    engine = self.llm_engine
    engine.start_profile()
def stop_profile(self) -> None:
    """Forward a stop-profile request to ``self.llm_engine``."""
    engine = self.llm_engine
    engine.stop_profile()
# LEGACY
def
_convert_v1_inputs
(
self
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
4851c202
...
...
@@ -35,11 +35,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeResponse
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
LoadLoraAdapterRequest
,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
TokenizeResponse
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
...
...
@@ -343,6 +345,40 @@ if envs.VLLM_TORCH_PROFILER_DIR:
return
Response
(
status_code
=
200
)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    # The adapter management routes below are only registered when the
    # environment flag is enabled.
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest):
        """Register a LoRA adapter with the chat and completion servers.

        The chat server is asked first; the first error response from
        either server is returned as JSON with that error's status code.
        On success the completion server's message is the response body.
        """
        for serving in (openai_serving_chat, openai_serving_completion):
            response = await serving.load_lora_adapter(request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
        """Remove a LoRA adapter from the chat and completion servers.

        Mirrors ``load_lora_adapter``: chat server first, first error
        wins, otherwise the completion server's message is returned.
        """
        for serving in (openai_serving_chat, openai_serving_completion):
            response = await serving.unload_lora_adapter(request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
4851c202
...
...
@@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel):
function
:
Optional
[
DeltaFunctionCall
]
=
None
# the initial delta that gets sent once a new tool call is started;
class
InitialDeltaToolCall
(
DeltaToolCall
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-tool-
{
random_uuid
()
}
"
)
type
:
Literal
[
"function"
]
=
"function"
index
:
int
class
ExtractedToolCallInformation
(
BaseModel
):
# indicate if tools were called
tools_called
:
bool
...
...
@@ -878,3 +871,13 @@ class DetokenizeRequest(OpenAIBaseModel):
class
DetokenizeResponse
(
OpenAIBaseModel
):
prompt
:
str
# Request body for the runtime "load LoRA adapter" endpoint.
# NOTE: kept as `#` comments rather than a class docstring on purpose.
class LoadLoraAdapterRequest(BaseModel):
    # Name under which the adapter is registered.
    lora_name: str
    # Location of the adapter weights (forwarded as `lora_path`).
    lora_path: str
# Request body for the runtime "unload LoRA adapter" endpoint.
class UnloadLoraAdapterRequest(BaseModel):
    # Name of the adapter to remove.
    lora_name: str
    # Optional numeric adapter id; defaults to None when omitted.
    lora_int_id: Optional[int] = None
vllm/entrypoints/openai/run_batch.py
View file @
4851c202
import
asyncio
from
http
import
HTTPStatus
from
io
import
StringIO
from
typing
import
Awaitable
,
Callable
,
List
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
import
aiohttp
import
torch
from
prometheus_client
import
start_http_server
from
tqdm
import
tqdm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
...
...
@@ -78,6 +81,38 @@ def parse_args():
return
parser
.
parse_args
()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT
=
"{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]
\n
"
# noqa: E501
class BatchProgressTracker:
    """Count submitted batch requests and report completions on a tqdm bar.

    Call ``submitted()`` once per queued request, then ``pbar()`` to create
    the bar sized to that count, and ``completed()`` as each one finishes.
    """

    def __init__(self):
        # Number of requests registered through submitted().
        self._total = 0
        # Created lazily by pbar(); None until then.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Record that one more request has been queued."""
        self._total += 1

    def completed(self):
        """Advance the progress bar by one, if it has been created."""
        # Compare against None explicitly: tqdm's truthiness reflects its
        # `total`, not whether a bar object exists.
        if self._pbar is not None:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """Create, store and return the progress bar for this batch.

        The bar is only enabled when torch.distributed is not initialized
        or this process is rank 0; otherwise it is created disabled.
        """
        enable_tqdm = (not torch.distributed.is_initialized()
                       or torch.distributed.get_rank() == 0)
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar
async
def
read_file
(
path_or_url
:
str
)
->
str
:
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
async
with
aiohttp
.
ClientSession
()
as
session
,
\
...
...
@@ -101,8 +136,28 @@ async def write_file(path_or_url: str, data: str) -> None:
f
.
write
(
data
)
def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Build the batch output entry describing a failed request.

    Args:
        request: The batch input whose ``custom_id`` is echoed back.
        error_msg: Text stored in the output's ``error`` field.

    Returns:
        A ``BatchRequestOutput`` with freshly generated ids and a
        ``BAD_REQUEST`` response status.
    """
    output_id = f"vllm-{random_uuid()}"
    failure_response = BatchResponseData(
        status_code=HTTPStatus.BAD_REQUEST,
        request_id=f"vllm-batch-{random_uuid()}",
    )
    return BatchRequestOutput(
        id=output_id,
        custom_id=request.custom_id,
        response=failure_response,
        error=error_msg,
    )
async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    """Awaitable wrapper around ``make_error_request_output``.

    Lets an error entry be gathered alongside the other request
    coroutines.
    """
    error_output = make_error_request_output(request, error_msg)
    return error_output
async
def
run_request
(
serving_engine_func
:
Callable
,
request
:
BatchRequestInput
)
->
BatchRequestOutput
:
request
:
BatchRequestInput
,
tracker
:
BatchProgressTracker
)
->
BatchRequestOutput
:
response
=
await
serving_engine_func
(
request
.
body
)
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
)):
...
...
@@ -123,8 +178,10 @@ async def run_request(serving_engine_func: Callable,
error
=
response
,
)
else
:
raise
ValueError
(
"Request must not be sent in stream mode"
)
batch_output
=
make_error_request_output
(
request
,
error_msg
=
"Request must not be sent in stream mode"
)
tracker
.
completed
()
return
batch_output
...
...
@@ -164,6 +221,9 @@ async def main(args):
request_logger
=
request_logger
,
)
tracker
=
BatchProgressTracker
()
logger
.
info
(
"Reading batch from %s..."
,
args
.
input_file
)
# Submit all requests in the file to the engine "concurrently".
response_futures
:
List
[
Awaitable
[
BatchRequestOutput
]]
=
[]
for
request_json
in
(
await
read_file
(
args
.
input_file
)).
strip
().
split
(
"
\n
"
):
...
...
@@ -178,16 +238,23 @@ async def main(args):
if
request
.
url
==
"/v1/chat/completions"
:
response_futures
.
append
(
run_request
(
openai_serving_chat
.
create_chat_completion
,
request
))
request
,
tracker
))
tracker
.
submitted
()
elif
request
.
url
==
"/v1/embeddings"
:
response_futures
.
append
(
run_request
(
openai_serving_embedding
.
create_embedding
,
request
))
run_request
(
openai_serving_embedding
.
create_embedding
,
request
,
tracker
))
tracker
.
submitted
()
else
:
raise
ValueError
(
"Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint."
)
responses
=
await
asyncio
.
gather
(
*
response_futures
)
response_futures
.
append
(
make_async_error_request_output
(
request
,
error_msg
=
"Only /v1/chat/completions and "
"/v1/embeddings are supported in the batch endpoint."
,
))
with
tracker
.
pbar
():
responses
=
await
asyncio
.
gather
(
*
response_futures
)
output_buffer
=
StringIO
()
for
response
in
responses
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
4851c202
...
...
@@ -11,7 +11,8 @@ from fastapi import Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
...
...
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
tool
.
model_dump
()
for
tool
in
request
.
tools
]
prompt
=
apply_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
request
.
messages
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -271,9 +284,13 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for
i
in
range
(
num_choices
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
delta
=
DeltaMessage
(
role
=
role
,
content
=
""
,
),
logprobs
=
None
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
...
...
@@ -303,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# last message
if
request
.
echo
:
last_msg_content
:
Optional
[
str
]
=
""
if
conversation
and
conversation
[
-
1
].
get
(
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
last_msg_content
:
str
=
""
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
if
last_msg_content
:
for
i
in
range
(
num_choices
):
...
...
@@ -655,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
echo
:
last_msg_content
=
""
if
conversation
and
conversation
[
-
1
]
.
get
(
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
for
choice
in
choices
:
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
4851c202
...
...
@@ -16,11 +16,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
ErrorResponse
,
LoadLoraAdapterRequest
,
ModelCard
,
ModelList
,
ModelPermission
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
TokenizeRequest
)
TokenizeRequest
,
UnloadLoraAdapterRequest
)
# yapf: enable
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
...
...
@@ -32,6 +34,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
AtomicCounter
logger
=
init_logger
(
__name__
)
...
...
@@ -78,6 +81,7 @@ class OpenAIServing:
self
.
served_model_names
=
served_model_names
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_requests
=
[]
if
lora_modules
is
not
None
:
self
.
lora_requests
=
[
...
...
@@ -403,3 +407,76 @@ class OpenAIServing:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
tokenizer
.
decode
(
token_id
)
async
def
_check_load_lora_adapter_request
(
self
,
request
:
LoadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if both 'lora_name' and 'lora_path' are provided
if
not
request
.
lora_name
or
not
request
.
lora_path
:
return
self
.
create_error_response
(
message
=
"Both 'lora_name' and 'lora_path' must be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name already exists
if
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' has already been"
"loaded."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async
def
_check_unload_lora_adapter_request
(
self
,
request
:
UnloadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if
not
request
.
lora_name
and
not
request
.
lora_int_id
:
return
self
.
create_error_response
(
message
=
"either 'lora_name' and 'lora_int_id' needs to be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name exists
if
not
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' cannot be found."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async def load_lora_adapter(
        self,
        request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
    """Register a new LoRA adapter described by ``request``.

    Returns the validation error unchanged when the request is rejected;
    otherwise appends a ``LoRARequest`` with a fresh id from
    ``self.lora_id_counter`` to ``self.lora_requests`` and returns a
    success message.
    """
    rejection = await self._check_load_lora_adapter_request(request)
    if rejection is not None:
        return rejection

    lora_name = request.lora_name
    lora_path = request.lora_path
    unique_id = self.lora_id_counter.inc(1)
    new_adapter = LoRARequest(lora_name=lora_name,
                              lora_int_id=unique_id,
                              lora_path=lora_path)
    self.lora_requests.append(new_adapter)
    return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
        self,
        request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
    """Remove every registered adapter named ``request.lora_name``.

    Returns the validation error unchanged when the request is rejected;
    otherwise rebinds ``self.lora_requests`` to a list without the named
    adapter and returns a success message.
    """
    rejection = await self._check_unload_lora_adapter_request(request)
    if rejection is not None:
        return rejection

    lora_name = request.lora_name
    remaining = []
    for existing in self.lora_requests:
        if existing.lora_name != lora_name:
            remaining.append(existing)
    self.lora_requests = remaining
    return f"Success: LoRA adapter '{lora_name}' removed successfully."
vllm/entrypoints/openai/serving_tokenization.py
View file @
4851c202
...
...
@@ -2,7 +2,8 @@ from typing import List, Optional, Union
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
...
...
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
...
...
@@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
prompt
=
apply_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
request
.
messages
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
prompt
=
request
.
prompt
...
...
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
View file @
4851c202
...
...
@@ -20,7 +20,6 @@ class ToolParser:
# the index of the tool call that is currently being parsed
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[]
self
.
model_tokenizer
=
tokenizer
...
...
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
4851c202
...
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
InitialDeltaToolCall
,
ToolCall
)
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -34,7 +34,6 @@ class Hermes2ProToolParser(ToolParser):
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
...
...
@@ -168,7 +167,6 @@ class Hermes2ProToolParser(ToolParser):
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
...
...
@@ -218,24 +216,16 @@ class Hermes2ProToolParser(ToolParser):
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
return
None
# case - we haven't sent the initial delta with the tool call ID
# (it will be sent)
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
el
if
not
self
.
current_tool_name_sent
:
if
not
self
.
current_tool_name_sent
:
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
f
"chatcmpl-tool-
{
random_uuid
()
}
"
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
4851c202
...
...
@@ -8,14 +8,14 @@ from partial_json_parser.core.options import Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
InitialDeltaToolCall
,
ToolCall
)
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -25,7 +25,7 @@ class MistralToolParser(ToolParser):
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser
g
mistral are all set
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
...
...
@@ -42,7 +42,6 @@ class MistralToolParser(ToolParser):
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
...
...
@@ -91,7 +90,6 @@ class MistralToolParser(ToolParser):
except
Exception
as
e
:
logger
.
error
(
"Error in extracting tool call from response: %s"
,
e
)
print
(
"ERROR"
,
e
)
# return information to just treat the tool call as regular JSON
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
...
...
@@ -109,7 +107,7 @@ class MistralToolParser(ToolParser):
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if
self
.
bot_token
_id
not
in
current_t
oken_ids
:
if
self
.
bot_token
not
in
current_t
ext
:
return
DeltaMessage
(
content
=
delta_text
)
# if the tool call token ID IS in the tokens generated so far, that
...
...
@@ -134,7 +132,7 @@ class MistralToolParser(ToolParser):
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
1
]
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
-
1
]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
...
...
@@ -186,31 +184,22 @@ class MistralToolParser(ToolParser):
# re-set stuff pertaining to progress in the current tool
self
.
current_tool_id
=
len
(
tool_call_arr
)
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"starting on new tool %d"
,
self
.
current_tool_id
)
return
delta
# case: update an existing tool - this is handled below
# if the current tool initial data incl. the id, type=function
# and idx not sent, send that
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
delta
=
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
el
if
not
self
.
current_tool_name_sent
:
if
not
self
.
current_tool_name_sent
:
function_name
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
f
"chatcmpl-tool-
{
random_uuid
()
}
"
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
...
...
vllm/envs.py
View file @
4851c202
...
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -226,6 +227,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
@@ -433,6 +439,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
# If set, allow loading or unloading LoRA adapters at runtime.
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
# end-env-vars-definition
...
...
vllm/executor/cpu_executor.py
View file @
4851c202
...
...
@@ -5,7 +5,8 @@ from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
...
...
@@ -60,6 +61,8 @@ class CPUExecutor(ExecutorBase):
self
.
cache_config
=
_verify_and_get_cache_config
(
self
.
cache_config
)
self
.
scheduler_config
=
_verify_and_get_scheduler_config
(
self
.
scheduler_config
)
self
.
parallel_config
=
_verify_and_get_parallel_config
(
self
.
parallel_config
)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
...
...
@@ -296,6 +299,12 @@ class CPUExecutor(ExecutorBase):
for
result
in
parallel_worker_tasks
:
result
.
get
()
def start_profile(self) -> None:
    """Invoke ``start_profile`` on the driver worker via the invoker."""
    invoke = self.driver_method_invoker
    invoke(self.driver_worker, "start_profile")
def stop_profile(self) -> None:
    """Invoke ``stop_profile`` on the driver worker via the invoker."""
    invoke = self.driver_method_invoker
    invoke(self.driver_worker, "stop_profile")
class
CPUExecutorAsync
(
CPUExecutor
,
ExecutorAsyncBase
):
...
...
@@ -353,6 +362,16 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
return
config
def
_verify_and_get_parallel_config
(
config
:
ParallelConfig
)
->
ParallelConfig
:
if
(
config
.
distributed_executor_backend
is
not
None
and
config
.
distributed_executor_backend
!=
"mp"
):
logger
.
warning
(
"%s is not supported on CPU, fallback to mp distributed executor "
"backend."
,
config
.
distributed_executor_backend
)
config
.
distributed_executor_backend
=
"mp"
return
config
def
_driver_method_invoker
(
driver
,
method
:
str
,
*
args
,
**
kwargs
):
return
getattr
(
driver
,
method
)(
*
args
,
**
kwargs
)
...
...
vllm/executor/gpu_executor.py
View file @
4851c202
...
...
@@ -169,6 +169,12 @@ class GPUExecutor(ExecutorBase):
# it's running.
return
def start_profile(self) -> None:
    """Forward a start-profile request to ``self.driver_worker``."""
    worker = self.driver_worker
    worker.start_profile()
def stop_profile(self) -> None:
    """Forward a stop-profile request to ``self.driver_worker``."""
    worker = self.driver_worker
    worker.stop_profile()
class
GPUExecutorAsync
(
GPUExecutor
,
ExecutorAsyncBase
):
...
...
vllm/executor/ray_gpu_executor.py
View file @
4851c202
...
...
@@ -242,6 +242,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
**
({
"VLLM_ATTENTION_BACKEND"
:
envs
.
VLLM_ATTENTION_BACKEND
}
if
envs
.
VLLM_ATTENTION_BACKEND
is
not
None
else
{})
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
self
.
_env_vars_for_all_workers
=
(
...
...
@@ -427,18 +430,34 @@ class RayGPUExecutor(DistributedGPUExecutor):
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
def
_c
ompiled
_ray_dag
(
self
,
enable_asyncio
:
bool
):
def _check_ray_adag_installation(self):
    """Verify the prerequisites for Ray accelerated DAG (aDAG) execution.

    Three checks are made, in order:

    1. The installed ``ray`` distribution is at least version 2.35.
    2. The module ``ray.experimental.compiled_dag_ref`` can be located.
    3. ``cupy`` can be located, but only when
       ``envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is truthy.

    Raises:
        ValueError: if any of the checks above fails.
    """
    import importlib.util

    import pkg_resources
    from packaging import version

    required_version = version.parse("2.35")
    current_version = version.parse(
        pkg_resources.get_distribution("ray").version)
    if current_version < required_version:
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    # find_spec only locates the module; it does not execute the target
    # module itself.
    adag_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
    if adag_spec is None:
        raise ValueError("Ray accelerated DAG is not installed. "
                         "Run `pip install ray[adag]` to install it.")

    cupy_spec = importlib.util.find_spec("cupy")
    if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
        raise ValueError(
            "cupy is not installed but required since "
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
            "Run `pip install ray[adag]` and check cupy installation.")
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
assert
self
.
parallel_config
.
use_ray
self
.
_check_ray_adag_installation
()
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.experimental.channel.torch_tensor_type
import
TorchTensorType
...
...
vllm/lora/layers.py
View file @
4851c202
...
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
if
hasattr
(
base_layer
,
"weight"
):
return
base_layer
.
weight
.
device
# GPTQ/AWQ
/SqueezeLLM
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
# marlin
...
...
vllm/lora/request.py
View file @
4851c202
...
...
@@ -28,7 +28,6 @@ class LoRARequest(
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__struct_fields__
:
...
...
@@ -75,3 +74,21 @@ class LoRARequest(
DeprecationWarning
,
stacklevel
=
2
)
self
.
lora_path
=
value
def __eq__(self, value: object) -> bool:
    """
    Equality for LoRARequest is decided by lora_name alone, so that the
    same adapter can be recognised and compared across engines.
    Objects that are not instances of this class never compare equal.
    """
    if not isinstance(value, self.__class__):
        return False
    return self.lora_name == value.lora_name
def __hash__(self) -> int:
    """
    Hash a LoRARequest by its lora_name only, matching the name-based
    equality, so instances can serve as set members and dictionary keys
    that are identified by name across engines.
    """
    name = self.lora_name
    return hash(name)
vllm/model_executor/layers/fused_moe/__init__.py
View file @
4851c202
...
...
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.triton_utils
import
HAS_TRITON
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
]
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
,
]
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_
marlin_
moe
,
fused_
moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
fused_experts
,
fused_moe
,
fused_
topk
,
get_config_file_name
,
grouped_topk
)
__all__
+=
[
"fused_marlin_moe"
,
"single_marlin_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
0 → 100644
View file @
4851c202
"""Fused MoE utilities for GPTQ."""
import
functools
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
def single_marlin_moe(
        hidden_states: torch.Tensor,
        w: torch.Tensor,
        scales: torch.Tensor,
        gating_output: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        topk: int,
        renormalize: bool,
        override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
    """
    This function computes the multiplication of hidden_states with expert
    weights used in Marlin MoE, using weights w and top-k gating mechanism.
    Its purpose is testing and debugging the fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
        Must be 2-D, contiguous and float16.
    - w (torch.Tensor): The set of expert weights. Must be contiguous, with
        w.shape[0] equal to the number of experts and w.shape[1] * 16 equal
        to the hidden size.
    - scales (torch.Tensor): The quantization scales.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx (torch.Tensor): The act_order indices.
    - perm (torch.Tensor): The act_order input permutation.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.

    Returns:
    - torch.Tensor: The kernel output summed over dim 1.

    Raises:
    - AssertionError: If any of the shape, contiguity or dtype constraints
        checked below is violated.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w.is_contiguous(), "Expert weights must be contiguous"
    assert hidden_states.dtype == torch.float16, (
        "hidden_states must be float16")

    M, K = hidden_states.shape
    E = w.shape[0]
    N = w.shape[2] // 2

    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                        renormalize)

    # This might not be an optimal config for a single MMM
    config = try_get_optimal_moe_config(w.shape,
                                        w.shape,
                                        topk_ids.shape[1],
                                        None,
                                        M,
                                        override_config=override_config,
                                        is_marlin=True)
    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # Allocate the kernel workspace on the same device as the activations
    # rather than on whichever CUDA device happens to be current.
    max_workspace_size = (N // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
        g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
        False)

    # Reduce over dim 1 (presumably the top-k expert axis - TODO confirm
    # against the kernel's output layout).
    return torch.sum(intermediate_cache, dim=1)
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    g_idx1: torch.Tensor,
    g_idx2: torch.Tensor,
    perm1: torch.Tensor,
    perm2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    override_config: Optional[Dict[str, Any]] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be 2-D, contiguous and float16.
    - w1 (torch.Tensor): The first set of expert weights (contiguous).
    - w2 (torch.Tensor): The second set of expert weights (contiguous).
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax). Only its shape is checked here; routing is taken
        from topk_weights / topk_ids.
    - g_idx1 (torch.Tensor): The first set of act_order indices.
    - g_idx2 (torch.Tensor): The second set of act_order indices.
    - perm1 (torch.Tensor): The first act_order input permutation.
    - perm2 (torch.Tensor): The second act_order input permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k elements; topk_ids.shape[1]
        is used as the number of selected experts per token.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.

    Raises:
    - AssertionError: If any of the shape, contiguity or dtype constraints
        checked below is violated.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[
        0], "Number of tokens mismatch"
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16, (
        "hidden_states must be float16")

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    config = try_get_optimal_moe_config(
        w1.shape,
        w2.shape,
        topk,
        None,
        M,
        override_config=override_config,
        is_marlin=True,
    )
    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # The workspace is shared by both GEMM calls. Allocate it on the same
    # device as the activations (like intermediate_cache2 below) rather than
    # on whichever CUDA device happens to be current.
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    # Destination buffer for the activation between the two GEMMs.
    intermediate_cache2 = torch.empty(
        (M * topk, N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # First GEMM: hidden_states x w1, producing 2 * N columns per row.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        g_idx1,
        perm1,
        workspace,
        M,
        2 * N,
        K,
        True,
        E,
        topk,
        block_size_m,
        True,
        False,
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM: activation x w2. Note the last two flags are swapped
    # relative to the first call.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        g_idx2,
        perm2,
        workspace,
        M,
        K,
        N,
        True,
        E,
        topk,
        block_size_m,
        False,
        True,
    )

    # Reduce over dim 1 (presumably the top-k expert axis - TODO confirm
    # against the kernel's output layout).
    return torch.sum(intermediate_cache3, dim=1)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
4851c202
...
...
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
is_marlin
:
bool
)
->
Dict
[
str
,
int
]:
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
)
->
Dict
[
str
,
int
]:
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
}
# A heuristic: fused marlin works faster with this config for small M
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
config
=
{
'BLOCK_SIZE_M'
:
16
,
...
...
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
w2_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
):
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
w2_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
,
):
if
override_config
:
config
=
override_config
else
:
...
...
@@ -391,6 +399,7 @@ def fused_topk(
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
ops
.
topk_softmax
(
topk_weights
,
topk_ids
,
...
...
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     gating_output: torch.Tensor,
                     g_idx1: torch.Tensor,
                     g_idx2: torch.Tensor,
                     rand_perm1: torch.Tensor,
                     rand_perm2: torch.Tensor,
                     topk: int,
                     custom_routing_function: Optional[Callable] = None,
                     renormalize: bool = True,
                     override_config: Optional[Dict[str, Any]] = None,
                     use_fp8: bool = False,
                     w1_scale: Optional[torch.Tensor] = None,
                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be 2-D, contiguous and float32/float16/bfloat16.
    - w1 (torch.Tensor): The first set of expert weights (contiguous).
    - w2 (torch.Tensor): The second set of expert weights (contiguous).
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx1 (torch.Tensor): Passed to the first GEMM call
        (presumably act_order indices - confirm against the kernel).
    - g_idx2 (torch.Tensor): Same as g_idx1, for the second GEMM call.
    - rand_perm1 (torch.Tensor): Passed to the first GEMM call
        (presumably the act_order input permutation - confirm).
    - rand_perm2 (torch.Tensor): Same as rand_perm1, for the second GEMM call.
    - topk (int): The number of top-k experts to select.
    - custom_routing_function (Optional[Callable]): If given, it is called as
        custom_routing_function(hidden_states, gating_output, topk,
        renormalize) in place of fused_topk and must return
        (topk_weights, topk_ids).
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Not implemented yet: must be False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.

    Raises:
    - AssertionError: If any of the constraints checked below is violated,
        or if use_fp8 is True.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[
        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    #TODO fp8 is not implemented yet
    assert not use_fp8

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16

    if custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                            renormalize)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize)

    config = try_get_optimal_moe_config(w1.shape,
                                        w2.shape,
                                        topk_ids.shape[1],
                                        "float8" if use_fp8 else None,
                                        M,
                                        override_config=override_config,
                                        is_marlin=True)
    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # The workspace is shared by both GEMM calls. Allocate it on the same
    # device as the activations (like intermediate_cache2 below) rather than
    # on whichever CUDA device happens to be current.
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    # Destination buffer for the activation between the two GEMMs.
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    # First GEMM: hidden_states x w1, producing 2 * N columns per row.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
        block_size_m, True, False)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM: activation x w2. Note the last two flags are swapped
    # relative to the first call.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
        block_size_m, False, True)

    # Reduce over dim 1 (presumably the top-k expert axis - TODO confirm
    # against the kernel's output layout).
    return torch.sum(intermediate_cache3, dim=1)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
4851c202
...
...
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
param_data
[
expert_id
]
=
loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
    """Load a g_idx tensor for one expert shard.

    Args:
        shard_id: One of "w1", "w2" or "w3".
        expert_data: Destination tensor for this expert.
        shard_dim: Dimension along which to shard; only forwarded to
            ``_load_w2`` for the "w2" case.
        loaded_weight: Tensor read from the checkpoint.
        tp_rank: Tensor-parallel rank; only forwarded to ``_load_w2``.

    Raises:
        AssertionError: if ``shard_id`` is not "w1", "w2" or "w3".
    """
    if shard_id == "w2":
        # w2 is delegated to the regular w2 loading path.
        self._load_w2(shard_id=shard_id,
                      shard_dim=shard_dim,
                      loaded_weight=loaded_weight,
                      expert_data=expert_data,
                      tp_rank=tp_rank)
    else:
        # w1 / w3 are copied whole, without any sharding.
        assert shard_id in ("w1", "w3")
        expert_data.copy_(loaded_weight)
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
...
...
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed:
whether or not the parameter is transposed on disk
#
If transposed, the loaded weight will be transposed and the dim
#
to shard the loaded weight will be flipped.
# is_transposed:
if the dim to shard the weight
#
should be flipped. Required by GPTQ, compressed-tensors
#
should be whatever dimension intermediate_size is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
shard_dim
=
~
shard_dim
# Case weight_scales
if
"weight_scale"
in
weight_name
:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# Case input scale: input_scale loading is only supported for fp8
if
"input_scale"
in
weight_name
:
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case g_idx
if
"g_idx"
in
weight_name
:
self
.
_load_g_idx
(
shard_dim
=
0
,
shard_id
=
shard_id
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
return
# Case weight scales and zero_points
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
...
...
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
# Case weight_shape
if
"weight_shape"
in
weight_name
:
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case input scale
if
"input_scale"
in
weight_name
:
# Note: input_scale loading is only supported for fp8
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
# only required by compressed-tensors
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
...
...
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
else
:
param_data
[
expert_id
]
=
loaded_weight
\ No newline at end of file
param_data
[
expert_id
]
=
loaded_weight
Prev
1
…
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment