Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
466 additions
and
149 deletions
+466
-149
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+50
-26
vllm/entrypoints/cli/benchmark/base.py
vllm/entrypoints/cli/benchmark/base.py
+1
-0
vllm/entrypoints/cli/benchmark/main.py
vllm/entrypoints/cli/benchmark/main.py
+1
-0
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+5
-2
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-4
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+1
-1
vllm/envs.py
vllm/envs.py
+17
-0
vllm/lora/layers.py
vllm/lora/layers.py
+5
-0
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+1
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+8
-7
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+0
-4
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+17
-1
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
...used_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+170
-93
No files found.
vllm/entrypoints/chat_utils.py
View file @
31330101
...
...
@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
...
...
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_model_config
=
model_config
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
...
...
@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
@
property
def
mm_registry
(
self
):
return
MULTIMODAL_REGISTRY
@
staticmethod
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
...
...
@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<|endoftext10|>"
# 200010 (see vocab.json in hf model)
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"fuyu"
,
"paligemma"
,
"pixtral"
,
"mistral3"
):
if
model_type
in
(
"blip-2"
,
"florence2"
,
"fuyu"
,
"paligemma"
,
"pixtral"
,
"mistral3"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
==
"qwen"
:
...
...
@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
.
image_token_index
)
if
model_type
in
(
"aya_vision"
,
"chameleon"
,
"deepseek_vl_v2"
,
"internvl_chat"
,
"skywork_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
"h2ovl_chat"
,
"idefics3"
,
"smolvlm"
):
return
"<image>"
if
model_type
in
(
"mllama"
,
"llama4"
):
return
"<|image|>"
...
...
@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<|vision_start|><|image_pad|><|vision_end|>"
if
model_type
==
"molmo"
:
return
""
if
model_type
==
"idefics3"
:
return
"<image>"
if
model_type
==
"aria"
:
return
"<|fim_prefix|><|img|><|fim_suffix|>"
if
model_type
==
"gemma3"
:
...
...
@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
mm_registry
=
self
.
mm_registry
model_config
=
self
.
model_config
input_modality
=
modality
.
replace
(
"_embeds"
,
""
)
if
mm_registry
.
has_processor
(
model_config
):
mm_processor
=
mm_registry
.
create_processor
(
model_config
)
allowed_counts
=
mm_processor
.
info
.
get_allowed_mm_limits
()
allowed_count
=
allowed_counts
.
get
(
input_modality
,
0
)
else
:
mm_config
=
model_config
.
multimodal_config
if
mm_config
is
None
:
msg
=
"This model does not support multi-modal inputs"
raise
ValueError
(
msg
)
allowed_count
=
mm_config
.
get_limit_per_prompt
(
input_modality
)
current_count
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it."
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
...
...
@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
None
),
"image_embeds"
:
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
{}
),
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
None
),
"audio_url"
:
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
None
),
"input_audio"
:
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
{}
),
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
None
),
"refusal"
:
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
""
),
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
"video_url"
:
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
None
),
}
...
...
@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part(
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is
empty
, log a warning and skip
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
not
content
:
# content is
None
, log a warning and skip
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
logger
.
warning
(
"Skipping multimodal part (type: '%s') "
"with empty / unparsable content."
,
part_type
)
"Skipping multimodal part
'%s'
(type: '%s') "
"with empty / unparsable content."
,
part
,
part_type
)
return
None
if
part_type
in
(
"text"
,
"refusal"
):
...
...
@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template(
**
kwargs
,
)
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
tools
=
tools
,
**
kwargs
,
)
try
:
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
tools
=
tools
,
**
kwargs
,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step
except
AssertionError
as
e
:
raise
ValueError
from
e
vllm/entrypoints/cli/benchmark/base.py
View file @
31330101
...
...
@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand):
parser
=
subparsers
.
add_parser
(
self
.
name
,
help
=
self
.
help
,
description
=
self
.
help
,
usage
=
f
"vllm bench
{
self
.
name
}
[options]"
)
self
.
add_cli_args
(
parser
)
return
parser
vllm/entrypoints/cli/benchmark/main.py
View file @
31330101
...
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser
=
subparsers
.
add_parser
(
"bench"
,
help
=
"vLLM bench subcommand."
,
description
=
"vLLM bench subcommand."
,
usage
=
"vllm bench <bench_type> [options]"
)
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
dest
=
"bench_type"
)
...
...
vllm/entrypoints/cli/openai.py
View file @
31330101
...
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
chat_parser
=
subparsers
.
add_parser
(
"chat"
,
help
=
"Generate chat completions via the running API server"
,
help
=
"Generate chat completions via the running API server."
,
description
=
"Generate chat completions via the running API server."
,
usage
=
"vllm chat [options]"
)
_add_query_options
(
chat_parser
)
chat_parser
.
add_argument
(
...
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser
=
subparsers
.
add_parser
(
"complete"
,
help
=
(
"Generate text completions based on the given prompt "
"via the running API server"
),
"via the running API server."
),
description
=
(
"Generate text completions based on the given prompt "
"via the running API server."
),
usage
=
"vllm complete [options]"
)
_add_query_options
(
complete_parser
)
return
complete_parser
...
...
vllm/entrypoints/cli/serve.py
View file @
31330101
...
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
serve_parser
=
subparsers
.
add_parser
(
"serve"
,
help
=
"Start the vLLM OpenAI Compatible API server"
,
help
=
"Start the vLLM OpenAI Compatible API server."
,
description
=
"Start the vLLM OpenAI Compatible API server."
,
usage
=
"vllm serve [model_tag] [options]"
)
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
...
...
vllm/entrypoints/llm.py
View file @
31330101
...
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import
cloudpickle
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
...
@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
...
...
@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
...
...
@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
hf_token
=
hf_token
,
hf_overrides
=
hf_overrides
,
mm_processor_kwargs
=
mm_processor_kwargs
,
override_pooler_config
=
override_pooler_config
,
...
...
@@ -531,6 +536,16 @@ class LLM:
tokenizer
.
eos_token_id
,
length_penalty
)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if
any
(
"multi_modal_data"
in
prompt
and
prompt
[
"multi_modal_data"
]
is
not
None
for
prompt
in
prompts
):
logger
.
warning
(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!"
)
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
...
...
@@ -906,6 +921,11 @@ class LLM:
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
self
.
llm_engine
.
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
self
.
llm_engine
.
model_config
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
...
...
@@ -924,6 +944,8 @@ class LLM:
/
,
*
,
use_tqdm
:
bool
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
list
[
EmbeddingRequestOutput
]:
...
...
@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
...
...
@@ -953,6 +977,7 @@ class LLM:
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
31330101
...
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
if
self
.
guided_decoding_backend
is
None
:
self
.
guided_decoding_backend
=
"xgrammar"
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
31330101
...
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
logger
=
init_logger
(
__name__
)
...
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
31330101
...
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return
error_check_ret
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
...
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size."
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
try
:
(
lora_request
,
...
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
31330101
...
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class
PythonicToolParser
(
ToolParser
):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2
and Llama 4
models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
31330101
...
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# JSONDeco
r
der.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
...
...
vllm/envs.py
View file @
31330101
...
...
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
...
...
@@ -115,6 +116,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
def
get_default_cache_root
():
...
...
@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION"
:
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_VERSION"
,
None
)),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
,
"-1"
)),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
...
...
@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
...
...
@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_XGRAMMAR_CACHE_MB"
,
"512"
)),
}
# end-env-vars-definition
...
...
vllm/lora/layers.py
View file @
31330101
...
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
#TODO: Implement this
class
QKVCrossParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
pass
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
...
...
vllm/lora/models.py
View file @
31330101
...
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
adapter_type
=
'LoR
a
'
self
.
adapter_type
=
'LoR
A
'
@
property
def
capacity
(
self
)
->
int
:
...
...
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
View file @
31330101
...
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
sorted
=
Fals
e
,
sorted
=
Tru
e
,
return_counts
=
True
)
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
non_blocking
=
True
)
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
31330101
...
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
guided_params
.
backend
=
fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if
guided_params
.
backend
==
"auto"
:
guided_params
.
backend
=
"xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
...
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
xgr_installed
)
# xgrammar doesn't support regex, fallback to outlines
if
guided_params
.
regex
is
not
None
:
fallback_or_error
(
guided_params
,
"xgrammar does not support regex guided decoding."
,
"outlines"
)
# xgrammar doesn't support some JSON schema features
el
if
(
guided_params
.
json
is
not
None
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
if
(
guided_params
.
json
is
not
None
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
fallback_or_error
(
guided_params
,
"xgrammar does not support advanced JSON schema features like "
...
...
vllm/model_executor/guided_decoding/utils.py
View file @
31330101
...
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if
"pattern"
in
obj
:
return
True
# Check for enum restrictions
if
"enum"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
31330101
...
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import
torch
import
vllm.envs
from
vllm.logger
import
init_logger
try
:
...
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab
=
config_data
.
encoded_vocab
,
metadata
=
config_data
.
metadata
,
)
cache_size
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
tokenizer_info
,
max_threads
=
config
.
max_threads
,
cache_enabled
=
True
,
cache_limit_bytes
=
cache_size
,
)
return
cls
.
_cache
[
cache_key
]
...
...
@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
any_whitespace
:
bool
=
True
regex_str
:
str
|
None
=
None
max_threads
:
int
=
8
@
classmethod
...
...
@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
elif
guided_params
.
regex
:
return
cls
(
regex_str
=
guided_params
.
regex
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
else
:
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self
.
ctx
=
compiler
\
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
any_whitespace
)
elif
self
.
config
.
regex_str
:
self
.
ctx
=
compiler
.
compile_regex
(
self
.
config
.
regex_str
)
else
:
raise
ValueError
(
"Invalid configuration for xgrammar logits processor"
)
...
...
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
0 → 100644
View file @
31330101
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
31330101
...
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq(
@
triton
.
jit
def
fused_moe_kernel
(
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -605,11 +610,22 @@ def fused_moe_kernel(
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
or
use_int8_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
# channel-wise
elif
per_channel_quant
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
token_mask
,
other
=
0.0
)[:,
None
]
# tensor-wise
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
...
@@ -645,7 +661,11 @@ def fused_moe_kernel(
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
use_fp8_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
...
...
@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
...
...
@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
#BLOCK_SIZE_K=BLOCK_SIZE_K,
per_channel_quant
=
per_channel_quant
,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**
config
,
)
...
...
@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
...
...
@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1310,6 +1314,7 @@ direct_register_custom_op(
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
...
...
@@ -1325,6 +1330,7 @@ def outplace_fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1337,7 +1343,8 @@ def outplace_fused_experts(
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1372,6 +1380,7 @@ direct_register_custom_op(
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
...
...
@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
...
...
@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe
=
use_nn_moe
)
def
moe_kernel_prepare_input
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
else
:
qcurr_hidden_states
=
curr_hidden_states
a1q_scale
=
a1_scale
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
...
...
@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps"
:
4
}
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
B
=
w1
,
A_scale
=
a1_scale
,
B_scale
=
w1_scale
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
...
...
@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
a1
q
_scale
,
q
a1_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
...
...
@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
intermediate_cache2
,
a2_scale
,
block_shape
)
else
:
qintermediate_cache2
=
intermediate_cache2
a2q_scale
=
a2_scale
qintermediate_cache2
,
qa2_scale
=
moe_kernel_prepare_input
(
A
=
intermediate_cache2
,
B
=
w2
,
A_scale
=
a2_scale
,
B_scale
=
w2_scale
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
...
...
@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2
q
_scale
,
q
a2_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
...
...
@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1739,6 +1812,7 @@ def fused_moe(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1772,6 +1846,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
...
...
@@ -1821,6 +1897,7 @@ def fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment