Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
453 additions
and
127 deletions
+453
-127
vllm/entrypoints/cli/benchmark/main.py
vllm/entrypoints/cli/benchmark/main.py
+1
-0
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+5
-2
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-4
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+1
-1
vllm/envs.py
vllm/envs.py
+12
-0
vllm/lora/layers.py
vllm/lora/layers.py
+5
-0
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+1
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+8
-7
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+0
-4
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+17
-1
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
...used_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+183
-85
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-6
No files found.
vllm/entrypoints/cli/benchmark/main.py
View file @
9c4ecf15
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser
=
subparsers
.
add_parser
(
bench_parser
=
subparsers
.
add_parser
(
"bench"
,
"bench"
,
help
=
"vLLM bench subcommand."
,
help
=
"vLLM bench subcommand."
,
description
=
"vLLM bench subcommand."
,
usage
=
"vllm bench <bench_type> [options]"
)
usage
=
"vllm bench <bench_type> [options]"
)
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
dest
=
"bench_type"
)
dest
=
"bench_type"
)
...
...
vllm/entrypoints/cli/openai.py
View file @
9c4ecf15
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
chat_parser
=
subparsers
.
add_parser
(
chat_parser
=
subparsers
.
add_parser
(
"chat"
,
"chat"
,
help
=
"Generate chat completions via the running API server"
,
help
=
"Generate chat completions via the running API server."
,
description
=
"Generate chat completions via the running API server."
,
usage
=
"vllm chat [options]"
)
usage
=
"vllm chat [options]"
)
_add_query_options
(
chat_parser
)
_add_query_options
(
chat_parser
)
chat_parser
.
add_argument
(
chat_parser
.
add_argument
(
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser
=
subparsers
.
add_parser
(
complete_parser
=
subparsers
.
add_parser
(
"complete"
,
"complete"
,
help
=
(
"Generate text completions based on the given prompt "
help
=
(
"Generate text completions based on the given prompt "
"via the running API server"
),
"via the running API server."
),
description
=
(
"Generate text completions based on the given prompt "
"via the running API server."
),
usage
=
"vllm complete [options]"
)
usage
=
"vllm complete [options]"
)
_add_query_options
(
complete_parser
)
_add_query_options
(
complete_parser
)
return
complete_parser
return
complete_parser
...
...
vllm/entrypoints/cli/serve.py
View file @
9c4ecf15
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
serve_parser
=
subparsers
.
add_parser
(
serve_parser
=
subparsers
.
add_parser
(
"serve"
,
"serve"
,
help
=
"Start the vLLM OpenAI Compatible API server"
,
help
=
"Start the vLLM OpenAI Compatible API server."
,
description
=
"Start the vLLM OpenAI Compatible API server."
,
usage
=
"vllm serve [model_tag] [options]"
)
usage
=
"vllm serve [model_tag] [options]"
)
serve_parser
.
add_argument
(
"model_tag"
,
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
type
=
str
,
...
...
vllm/entrypoints/llm.py
View file @
9c4ecf15
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import
cloudpickle
import
cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
@@ -117,6 +117,9 @@ class LLM:
...
@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
HuggingFace config.
...
@@ -177,6 +180,7 @@ class LLM:
...
@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
# After positional args are removed, move this right below `model`
...
@@ -232,6 +236,7 @@ class LLM:
...
@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
disable_async_output_proc
=
disable_async_output_proc
,
hf_token
=
hf_token
,
hf_overrides
=
hf_overrides
,
hf_overrides
=
hf_overrides
,
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
override_pooler_config
=
override_pooler_config
,
override_pooler_config
=
override_pooler_config
,
...
@@ -531,6 +536,16 @@ class LLM:
...
@@ -531,6 +536,16 @@ class LLM:
tokenizer
.
eos_token_id
,
tokenizer
.
eos_token_id
,
length_penalty
)
length_penalty
)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if
any
(
"multi_modal_data"
in
prompt
and
prompt
[
"multi_modal_data"
]
is
not
None
for
prompt
in
prompts
):
logger
.
warning
(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!"
)
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# following the huggingface transformers implementation
...
@@ -906,6 +921,11 @@ class LLM:
...
@@ -906,6 +921,11 @@ class LLM:
if
pooling_params
is
None
:
if
pooling_params
is
None
:
# Use default pooling params.
# Use default pooling params.
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
self
.
llm_engine
.
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
self
.
llm_engine
.
model_config
)
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
prompts
=
parsed_prompts
,
...
@@ -924,6 +944,8 @@ class LLM:
...
@@ -924,6 +944,8 @@ class LLM:
/
,
/
,
*
,
*
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
list
[
EmbeddingRequestOutput
]:
)
->
list
[
EmbeddingRequestOutput
]:
...
@@ -938,6 +960,8 @@ class LLM:
...
@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
prompt_adapter_request: Prompt Adapter request to use for
...
@@ -953,6 +977,7 @@ class LLM:
...
@@ -953,6 +977,7 @@ class LLM:
items
=
self
.
encode
(
prompts
,
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
9c4ecf15
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
self
.
guided_json
=
json_schema
.
json_schema
if
self
.
guided_decoding_backend
is
None
:
self
.
guided_decoding_backend
=
"xgrammar"
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
return
data
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
9c4ecf15
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
9c4ecf15
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return
error_check_ret
return
error_check_ret
encoding_format
=
request
.
encoding_format
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
"greater than max_model_len."
" Please, select a smaller truncation size."
)
" Please, select a smaller truncation size."
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
9c4ecf15
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class
PythonicToolParser
(
ToolParser
):
class
PythonicToolParser
(
ToolParser
):
"""
"""
Tool call parser for models that produce tool calls in a pythonic style,
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2
and Llama 4
models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
"""
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
9c4ecf15
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# partial_json_parser doesn't support extra data and
# JSONDeco
r
der.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
try
:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
...
...
vllm/envs.py
View file @
9c4ecf15
...
@@ -106,6 +106,7 @@ if TYPE_CHECKING:
...
@@ -106,6 +106,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
# Use model_redirect to redirect the model name to a local folder.
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH"
:
"VLLM_MODEL_REDIRECT_PATH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
...
@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
"VLLM_USE_DEEP_GEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_XGRAMMAR_CACHE_MB"
,
"512"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/lora/layers.py
View file @
9c4ecf15
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
and
len
(
packed_modules_list
)
==
3
)
#TODO: Implement this
class
QKVCrossParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
pass
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
...
...
vllm/lora/models.py
View file @
9c4ecf15
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
model
.
lora_manager
=
self
self
.
adapter_type
=
'LoR
a
'
self
.
adapter_type
=
'LoR
A
'
@
property
@
property
def
capacity
(
self
)
->
int
:
def
capacity
(
self
)
->
int
:
...
...
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
View file @
9c4ecf15
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
# active_lora_ids, num_tokens_per_lora
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
sorted
=
Fals
e
,
sorted
=
Tru
e
,
return_counts
=
True
)
return_counts
=
True
)
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
non_blocking
=
True
)
non_blocking
=
True
)
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
9c4ecf15
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
guided_params
.
backend
=
fallback
guided_params
.
backend
=
fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if
guided_params
.
backend
==
"auto"
:
guided_params
.
backend
=
"xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
if
guided_params
.
grammar
is
not
None
:
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
xgr_installed
)
xgr_installed
)
# xgrammar doesn't support regex, fallback to outlines
if
guided_params
.
regex
is
not
None
:
fallback_or_error
(
guided_params
,
"xgrammar does not support regex guided decoding."
,
"outlines"
)
# xgrammar doesn't support some JSON schema features
# xgrammar doesn't support some JSON schema features
el
if
(
guided_params
.
json
is
not
None
if
(
guided_params
.
json
is
not
None
and
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
fallback_or_error
(
fallback_or_error
(
guided_params
,
guided_params
,
"xgrammar does not support advanced JSON schema features like "
"xgrammar does not support advanced JSON schema features like "
...
...
vllm/model_executor/guided_decoding/utils.py
View file @
9c4ecf15
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if
"pattern"
in
obj
:
if
"pattern"
in
obj
:
return
True
return
True
# Check for enum restrictions
if
"enum"
in
obj
:
return
True
# Check for numeric ranges
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
key
in
obj
for
key
in
[
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
9c4ecf15
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import
torch
import
torch
import
vllm.envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
try
:
try
:
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab
=
config_data
.
encoded_vocab
,
encoded_vocab
=
config_data
.
encoded_vocab
,
metadata
=
config_data
.
metadata
,
metadata
=
config_data
.
metadata
,
)
)
cache_size
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
tokenizer_info
,
max_threads
=
config
.
max_threads
,
cache_enabled
=
True
,
cache_limit_bytes
=
cache_size
,
)
return
cls
.
_cache
[
cache_key
]
return
cls
.
_cache
[
cache_key
]
...
@@ -146,6 +152,7 @@ class GrammarConfig:
...
@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
json_object
:
bool
|
None
=
None
any_whitespace
:
bool
=
True
any_whitespace
:
bool
=
True
regex_str
:
str
|
None
=
None
max_threads
:
int
=
8
max_threads
:
int
=
8
@
classmethod
@
classmethod
...
@@ -249,6 +256,13 @@ class GrammarConfig:
...
@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads
=
max_threads
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
tokenizer_data
=
tokenizer_data
,
)
)
elif
guided_params
.
regex
:
return
cls
(
regex_str
=
guided_params
.
regex
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self
.
ctx
=
compiler
\
self
.
ctx
=
compiler
\
.
compile_json_schema
(
'{"type": "object"}'
,
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
any_whitespace
)
any_whitespace
=
any_whitespace
)
elif
self
.
config
.
regex_str
:
self
.
ctx
=
compiler
.
compile_regex
(
self
.
config
.
regex_str
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid configuration for xgrammar logits processor"
)
"Invalid configuration for xgrammar logits processor"
)
...
...
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
0 → 100644
View file @
9c4ecf15
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
9c4ecf15
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq(
...
@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq(
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel
(
def
fused_moe_kernel
(
# Pointers to matrices
# Pointers to matrices
a_ptr
,
a_ptr
,
b_ptr
,
b_ptr
,
c_ptr
,
c_ptr
,
a_scale_ptr
,
a_scale_ptr
,
b_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
# Matrix dimensions
N
,
N
,
K
,
K
,
EM
,
EM
,
num_valid_tokens
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
# (A has M rows).
stride_am
,
stride_am
,
stride_ak
,
stride_ak
,
stride_be
,
stride_be
,
stride_bk
,
stride_bk
,
stride_bn
,
stride_bn
,
stride_cm
,
stride_cm
,
stride_cn
,
stride_cn
,
stride_asm
,
stride_asm
,
stride_ask
,
stride_ask
,
stride_bse
,
stride_bse
,
stride_bsk
,
stride_bsk
,
stride_bsn
,
stride_bsn
,
# Block size for block-wise quantization
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
):
"""
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
token and expert matrices.
...
@@ -385,12 +391,23 @@ def fused_moe_kernel(
...
@@ -385,12 +391,23 @@ def fused_moe_kernel(
None
,
:]
*
stride_bsn
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
or
use_int8_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
offs_bsn
*
stride_bsn
)
# channel-wise
elif
per_channel_quant
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
token_mask
,
other
=
0.0
)[:,
None
]
# tensor-wise
else
:
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
@@ -414,7 +431,7 @@ def fused_moe_kernel(
...
@@ -414,7 +431,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
# We accumulate along the K dimension.
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
offs_ks
=
k_start
//
group_k
...
@@ -426,7 +443,11 @@ def fused_moe_kernel(
...
@@ -426,7 +443,11 @@ def fused_moe_kernel(
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
None
]
*
b_scale
[
None
,
:]
else
:
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
use_fp8_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
# Advance the ptrs to the next K block.
...
@@ -440,7 +461,7 @@ def fused_moe_kernel(
...
@@ -440,7 +461,7 @@ def fused_moe_kernel(
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
else
:
...
@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
:
Dict
[
str
,
Any
],
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
num_tokens
=
M
*
top_k
...
@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k
=
top_k
,
top_k
=
top_k
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
per_channel_quant
=
per_channel_quant
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
**
config
,
**
config
,
)
)
...
@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
per_channel_quant
,
global_num_experts
,
expert_map
,
a2_scale
,
block_shape
,
use_nn_moe
)
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake(
...
@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake(
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1029,6 +1045,7 @@ direct_register_custom_op(
...
@@ -1029,6 +1045,7 @@ direct_register_custom_op(
op_func
=
inplace_fused_experts
,
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
fake_impl
=
inplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
...
@@ -1041,8 +1058,10 @@ def outplace_fused_experts(
...
@@ -1041,8 +1058,10 @@ def outplace_fused_experts(
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1055,7 +1074,8 @@ def outplace_fused_experts(
...
@@ -1055,7 +1074,8 @@ def outplace_fused_experts(
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
)
...
@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake(
...
@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1089,6 +1111,7 @@ direct_register_custom_op(
...
@@ -1089,6 +1111,7 @@ direct_register_custom_op(
op_func
=
outplace_fused_experts
,
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
fake_impl
=
outplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
...
@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation
=
activation
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
def
moe_kernel_prepare_input
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
if
use_fp8_w8a8
:
B
=
w1
,
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
A_scale
=
a1_scale
,
curr_hidden_states
,
a1_scale
,
block_shape
)
B_scale
=
w1_scale
,
else
:
use_fp8_w8a8
=
use_fp8_w8a8
,
qcurr_hidden_states
=
curr_hidden_states
use_int8_w8a8
=
use_int8_w8a8
,
a1q_scale
=
a1_scale
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
...
@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
a1
q
_scale
,
q
a1_scale
,
w1_scale
,
w1_scale
,
w1_zp
,
w1_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
qa2_scale
=
moe_kernel_prepare_input
(
A
=
intermediate_cache2
,
if
use_fp8_w8a8
:
B
=
w2
,
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
A_scale
=
a2_scale
,
intermediate_cache2
,
a2_scale
,
block_shape
)
B_scale
=
w2_scale
,
else
:
use_fp8_w8a8
=
use_fp8_w8a8
,
qintermediate_cache2
=
intermediate_cache2
use_int8_w8a8
=
use_int8_w8a8
,
a2q_scale
=
a2_scale
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
a2
q
_scale
,
q
a2_scale
,
w2_scale
,
w2_scale
,
w2_zp
,
w2_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1385,8 +1477,10 @@ def fused_moe(
...
@@ -1385,8 +1477,10 @@ def fused_moe(
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1420,6 +1514,8 @@ def fused_moe(
...
@@ -1420,6 +1514,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
activation to compute the inner products for w1 and w2.
Defaults to False.
Defaults to False.
...
@@ -1466,8 +1562,10 @@ def fused_moe(
...
@@ -1466,8 +1562,10 @@ def fused_moe(
inplace
=
inplace
,
inplace
=
inplace
,
activation
=
activation
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
9c4ecf15
...
@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
def
forward_cuda
(
def
forward_cuda
(
...
@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module):
# Use expert parallelism instead of tensor parallelism?
# Use expert parallelism instead of tensor parallelism?
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
use_ep
=
(
vllm_config
.
parallel_config
.
enable_expert_parallel
use_ep
=
(
vllm_config
.
parallel_config
.
enable_expert_parallel
and
self
.
tp_size
>
1
)
and
self
.
tp_size
*
self
.
dp_size
>
1
)
# For smuggling this layer into the fused moe custom op
# For smuggling this layer into the fused moe custom op
self
.
use_direct_call
=
self
.
dp_size
==
1
self
.
use_direct_call
=
self
.
dp_size
==
1
...
@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module):
}
}
# need full intermediate size pre-sharding for WNA16 act order
# need full intermediate size pre-sharding for WNA16 act order
if
(
self
.
quant_method
.
__class__
.
__name__
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
if
self
.
quant_method
.
__class__
.
__name__
in
(
self
.
quant_method
.
__class__
.
__name__
"CompressedTensorsWNA16MarlinMoEMethod"
,
==
"CompressedTensorsWNA16MoEMethod"
)
else
loaded_weight
"CompressedTensorsWNA16MoEMethod"
):
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
...
vllm/model_executor/layers/linear.py
View file @
9c4ecf15
...
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix
=
f
"
{
prefix
}
.kv_proj_encoder"
)
prefix
=
f
"
{
prefix
}
.kv_proj_encoder"
)
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self
.
q_size
=
self
.
q_proj_decoder
.
output_size_per_partition
self
.
kv_size
=
self
.
kv_proj_encoder
.
num_kv_heads
*
head_size
self
.
kv_size
=
self
.
kv_proj_encoder
.
num_kv_heads
*
head_size
if
bias
:
if
bias
:
...
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
def
process_weights_after_loading
(
self
):
for
layer
in
self
.
proj
.
values
():
if
self
.
quant_method
is
not
None
:
self
.
quant_method
.
process_weights_after_loading
(
layer
)
@
property
@
property
def
q_proj_decoder
(
self
)
->
ColumnParallelLinear
:
def
q_proj_decoder
(
self
)
->
ColumnParallelLinear
:
layer
=
self
.
proj
[
"q_proj_decoder"
]
layer
=
self
.
proj
[
"q_proj_decoder"
]
for
name
,
param
in
self
.
named_parameters
():
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
target_param
=
getattr
(
layer
,
name
,
None
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
return
layer
return
layer
@
property
@
property
def
kv_proj_encoder
(
self
)
->
QKVParallelLinear
:
def
kv_proj_encoder
(
self
)
->
QKVParallelLinear
:
layer
=
self
.
proj
[
"kv_proj_encoder"
]
layer
=
self
.
proj
[
"kv_proj_encoder"
]
for
name
,
param
in
self
.
named_parameters
():
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
target_param
=
getattr
(
layer
,
name
,
None
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
return
layer
return
layer
def
sync_weight_attrs
(
def
sync_weight_attrs
(
...
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
:
layer
.
weight_loader_v2
(
target_param
,
loaded_weight
,
*
shard_id_args
)
else
:
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", q_size=
{
self
.
q_
proj_decoder
.
output_size_per_partition
}
"
s
+=
f
", q_size=
{
self
.
q_
size
}
"
s
+=
f
", kv_size=
{
self
.
kv_size
}
"
s
+=
f
", kv_size=
{
self
.
kv_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment