Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
453 additions
and
127 deletions
+453
-127
vllm/entrypoints/cli/benchmark/main.py
vllm/entrypoints/cli/benchmark/main.py
+1
-0
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+5
-2
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-4
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+1
-1
vllm/envs.py
vllm/envs.py
+12
-0
vllm/lora/layers.py
vllm/lora/layers.py
+5
-0
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+1
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+8
-7
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+0
-4
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+17
-1
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
...used_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+183
-85
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-6
No files found.
vllm/entrypoints/cli/benchmark/main.py
View file @
9c4ecf15
...
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser
=
subparsers
.
add_parser
(
"bench"
,
help
=
"vLLM bench subcommand."
,
description
=
"vLLM bench subcommand."
,
usage
=
"vllm bench <bench_type> [options]"
)
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
dest
=
"bench_type"
)
...
...
vllm/entrypoints/cli/openai.py
View file @
9c4ecf15
...
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
chat_parser
=
subparsers
.
add_parser
(
"chat"
,
help
=
"Generate chat completions via the running API server"
,
help
=
"Generate chat completions via the running API server."
,
description
=
"Generate chat completions via the running API server."
,
usage
=
"vllm chat [options]"
)
_add_query_options
(
chat_parser
)
chat_parser
.
add_argument
(
...
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser
=
subparsers
.
add_parser
(
"complete"
,
help
=
(
"Generate text completions based on the given prompt "
"via the running API server"
),
"via the running API server."
),
description
=
(
"Generate text completions based on the given prompt "
"via the running API server."
),
usage
=
"vllm complete [options]"
)
_add_query_options
(
complete_parser
)
return
complete_parser
...
...
vllm/entrypoints/cli/serve.py
View file @
9c4ecf15
...
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
serve_parser
=
subparsers
.
add_parser
(
"serve"
,
help
=
"Start the vLLM OpenAI Compatible API server"
,
help
=
"Start the vLLM OpenAI Compatible API server."
,
description
=
"Start the vLLM OpenAI Compatible API server."
,
usage
=
"vllm serve [model_tag] [options]"
)
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
...
...
vllm/entrypoints/llm.py
View file @
9c4ecf15
...
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import
cloudpickle
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
...
@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote
files. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
...
...
@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
...
...
@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
hf_token
=
hf_token
,
hf_overrides
=
hf_overrides
,
mm_processor_kwargs
=
mm_processor_kwargs
,
override_pooler_config
=
override_pooler_config
,
...
...
@@ -531,6 +536,16 @@ class LLM:
tokenizer
.
eos_token_id
,
length_penalty
)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if
any
(
"multi_modal_data"
in
prompt
and
prompt
[
"multi_modal_data"
]
is
not
None
for
prompt
in
prompts
):
logger
.
warning
(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!"
)
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
...
...
@@ -906,6 +921,11 @@ class LLM:
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
self
.
llm_engine
.
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
self
.
llm_engine
.
model_config
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
...
...
@@ -924,6 +944,8 @@ class LLM:
/
,
*
,
use_tqdm
:
bool
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
list
[
EmbeddingRequestOutput
]:
...
...
@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
...
...
@@ -953,6 +977,7 @@ class LLM:
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
9c4ecf15
...
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
if
self
.
guided_decoding_backend
is
None
:
self
.
guided_decoding_backend
=
"xgrammar"
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
9c4ecf15
...
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
logger
=
init_logger
(
__name__
)
...
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
9c4ecf15
...
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return
error_check_ret
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
...
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size."
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
try
:
(
lora_request
,
...
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
9c4ecf15
...
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class
PythonicToolParser
(
ToolParser
):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
9c4ecf15
...
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# JSONDeco
r
der.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
...
...
vllm/envs.py
View file @
9c4ecf15
...
...
@@ -106,6 +106,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
def
get_default_cache_root
():
...
...
@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
...
...
@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
# Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_XGRAMMAR_CACHE_MB"
,
"512"
)),
}
# end-env-vars-definition
...
...
vllm/lora/layers.py
View file @
9c4ecf15
...
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
#TODO: Implement this
class
QKVCrossParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
pass
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
...
...
vllm/lora/models.py
View file @
9c4ecf15
...
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
adapter_type
=
'LoR
a
'
self
.
adapter_type
=
'LoR
A
'
@
property
def
capacity
(
self
)
->
int
:
...
...
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
View file @
9c4ecf15
...
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
sorted
=
Fals
e
,
sorted
=
Tru
e
,
return_counts
=
True
)
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
non_blocking
=
True
)
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
9c4ecf15
...
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
guided_params
.
backend
=
fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if
guided_params
.
backend
==
"auto"
:
guided_params
.
backend
=
"xgrammar"
# lm-format-enforcer doesn't support grammar, fall back to xgrammar
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
...
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
xgr_installed
)
# xgrammar doesn't support regex, fallback to outlines
if
guided_params
.
regex
is
not
None
:
fallback_or_error
(
guided_params
,
"xgrammar does not support regex guided decoding."
,
"outlines"
)
# xgrammar doesn't support some JSON schema features
el
if
(
guided_params
.
json
is
not
None
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
if
(
guided_params
.
json
is
not
None
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
fallback_or_error
(
guided_params
,
"xgrammar does not support advanced JSON schema features like "
...
...
vllm/model_executor/guided_decoding/utils.py
View file @
9c4ecf15
...
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if
"pattern"
in
obj
:
return
True
# Check for enum restrictions
if
"enum"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
9c4ecf15
...
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import
torch
import
vllm.envs
from
vllm.logger
import
init_logger
try
:
...
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab
=
config_data
.
encoded_vocab
,
metadata
=
config_data
.
metadata
,
)
cache_size
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
tokenizer_info
,
max_threads
=
config
.
max_threads
,
cache_enabled
=
True
,
cache_limit_bytes
=
cache_size
,
)
return
cls
.
_cache
[
cache_key
]
...
...
@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
any_whitespace
:
bool
=
True
regex_str
:
str
|
None
=
None
max_threads
:
int
=
8
@
classmethod
...
...
@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
elif
guided_params
.
regex
:
return
cls
(
regex_str
=
guided_params
.
regex
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
else
:
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self
.
ctx
=
compiler
\
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
any_whitespace
)
elif
self
.
config
.
regex_str
:
self
.
ctx
=
compiler
.
compile_regex
(
self
.
config
.
regex_str
)
else
:
raise
ValueError
(
"Invalid configuration for xgrammar logits processor"
)
...
...
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
0 → 100644
View file @
9c4ecf15
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
9c4ecf15
...
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq(
@
triton
.
jit
def
fused_moe_kernel
(
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -385,12 +391,23 @@ def fused_moe_kernel(
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
or
use_int8_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
# channel-wise
elif
per_channel_quant
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
token_mask
,
other
=
0.0
)[:,
None
]
# tensor-wise
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
...
@@ -414,7 +431,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
@@ -426,7 +443,11 @@ def fused_moe_kernel(
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
use_fp8_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
...
...
@@ -440,7 +461,7 @@ def fused_moe_kernel(
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
...
...
@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
...
...
@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
per_channel_quant
=
per_channel_quant
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
**
config
,
)
...
...
@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
...
...
@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake(
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1029,6 +1045,7 @@ direct_register_custom_op(
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
...
...
@@ -1041,8 +1058,10 @@ def outplace_fused_experts(
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1055,7 +1074,8 @@ def outplace_fused_experts(
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1089,6 +1111,7 @@ direct_register_custom_op(
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
...
...
@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
...
...
@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe
=
use_nn_moe
)
def
moe_kernel_prepare_input
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations use per-token quantization. Otherwise, assume
# tensor-wise fp8 activation quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
else
:
qcurr_hidden_states
=
curr_hidden_states
a1q_scale
=
a1_scale
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
B
=
w1
,
A_scale
=
a1_scale
,
B_scale
=
w1_scale
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
...
...
@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
a1
q
_scale
,
q
a1_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
...
...
@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
intermediate_cache2
,
a2_scale
,
block_shape
)
else
:
qintermediate_cache2
=
intermediate_cache2
a2q_scale
=
a2_scale
qintermediate_cache2
,
qa2_scale
=
moe_kernel_prepare_input
(
A
=
intermediate_cache2
,
B
=
w2
,
A_scale
=
a2_scale
,
B_scale
=
w2_scale
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2
q
_scale
,
q
a2_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
...
...
@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1385,8 +1477,10 @@ def fused_moe(
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1420,6 +1514,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
...
...
@@ -1466,8 +1562,10 @@ def fused_moe(
inplace
=
inplace
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
9c4ecf15
...
...
@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
)
def
forward_cuda
(
...
...
@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module):
# Use expert parallelism instead of tensor parallelism?
vllm_config
=
get_current_vllm_config
()
use_ep
=
(
vllm_config
.
parallel_config
.
enable_expert_parallel
and
self
.
tp_size
>
1
)
and
self
.
tp_size
*
self
.
dp_size
>
1
)
# For smuggling this layer into the fused moe custom op
self
.
use_direct_call
=
self
.
dp_size
==
1
...
...
@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module):
}
# need full intermediate size pre-sharding for WNA16 act order
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
...
@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsWNA16MoEMethod"
)
else
loaded_weight
if
self
.
quant_method
.
__class__
.
__name__
in
(
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
):
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
...
vllm/model_executor/layers/linear.py
View file @
9c4ecf15
...
...
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix
=
f
"
{
prefix
}
.kv_proj_encoder"
)
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self
.
q_size
=
self
.
q_proj_decoder
.
output_size_per_partition
self
.
kv_size
=
self
.
kv_proj_encoder
.
num_kv_heads
*
head_size
if
bias
:
...
...
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else
:
self
.
bias
=
None
def
process_weights_after_loading
(
self
):
for
layer
in
self
.
proj
.
values
():
if
self
.
quant_method
is
not
None
:
self
.
quant_method
.
process_weights_after_loading
(
layer
)
@
property
def
q_proj_decoder
(
self
)
->
ColumnParallelLinear
:
layer
=
self
.
proj
[
"q_proj_decoder"
]
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
target_param
=
getattr
(
layer
,
name
,
None
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
return
layer
@
property
def
kv_proj_encoder
(
self
)
->
QKVParallelLinear
:
layer
=
self
.
proj
[
"kv_proj_encoder"
]
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
target_param
=
getattr
(
layer
,
name
,
None
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
return
layer
def
sync_weight_attrs
(
...
...
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
:
layer
.
weight_loader_v2
(
target_param
,
loaded_weight
,
*
shard_id_args
)
else
:
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", q_size=
{
self
.
q_
proj_decoder
.
output_size_per_partition
}
"
s
+=
f
", q_size=
{
self
.
q_
size
}
"
s
+=
f
", kv_size=
{
self
.
kv_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment