Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
466 additions
and
149 deletions
+466
-149
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+50
-26
vllm/entrypoints/cli/benchmark/base.py
vllm/entrypoints/cli/benchmark/base.py
+1
-0
vllm/entrypoints/cli/benchmark/main.py
vllm/entrypoints/cli/benchmark/main.py
+1
-0
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+5
-2
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-4
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+1
-1
vllm/envs.py
vllm/envs.py
+17
-0
vllm/lora/layers.py
vllm/lora/layers.py
+5
-0
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+1
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+8
-7
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+0
-4
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+17
-1
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
...used_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+170
-93
No files found.
vllm/entrypoints/chat_utils.py
View file @
31330101
...
@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
...
@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
...
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_model_config
=
model_config
self
.
_model_config
=
model_config
self
.
_tokenizer
=
tokenizer
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
...
@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
allowed_local_media_path
(
self
):
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
return
self
.
_model_config
.
allowed_local_media_path
@
property
def
mm_registry
(
self
):
return
MULTIMODAL_REGISTRY
@
staticmethod
@
staticmethod
@
cache
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
...
@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<|endoftext10|>"
# 200010 (see vocab.json in hf model)
return
"<|endoftext10|>"
# 200010 (see vocab.json in hf model)
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
return
"(<image>./</image>)"
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"fuyu"
,
"paligemma"
,
"pixtral"
,
if
model_type
in
(
"blip-2"
,
"florence2"
,
"fuyu"
,
"paligemma"
,
"mistral3"
):
"pixtral"
,
"mistral3"
):
# These models do not use image tokens in the prompt
# These models do not use image tokens in the prompt
return
None
return
None
if
model_type
==
"qwen"
:
if
model_type
==
"qwen"
:
...
@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
.
image_token_index
)
hf_config
.
image_token_index
)
if
model_type
in
(
"aya_vision"
,
"chameleon"
,
"deepseek_vl_v2"
,
if
model_type
in
(
"aya_vision"
,
"chameleon"
,
"deepseek_vl_v2"
,
"internvl_chat"
,
"skywork_chat"
,
"NVLM_D"
,
"internvl_chat"
,
"skywork_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
"h2ovl_chat"
,
"idefics3"
,
"smolvlm"
):
return
"<image>"
return
"<image>"
if
model_type
in
(
"mllama"
,
"llama4"
):
if
model_type
in
(
"mllama"
,
"llama4"
):
return
"<|image|>"
return
"<|image|>"
...
@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<|vision_start|><|image_pad|><|vision_end|>"
return
"<|vision_start|><|image_pad|><|vision_end|>"
if
model_type
==
"molmo"
:
if
model_type
==
"molmo"
:
return
""
return
""
if
model_type
==
"idefics3"
:
return
"<image>"
if
model_type
==
"aria"
:
if
model_type
==
"aria"
:
return
"<|fim_prefix|><|img|><|fim_suffix|>"
return
"<|fim_prefix|><|img|><|fim_suffix|>"
if
model_type
==
"gemma3"
:
if
model_type
==
"gemma3"
:
...
@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
placeholder string to use, if any.
"""
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
mm_registry
=
self
.
mm_registry
model_config
=
self
.
model_config
input_modality
=
modality
.
replace
(
"_embeds"
,
""
)
if
mm_registry
.
has_processor
(
model_config
):
mm_processor
=
mm_registry
.
create_processor
(
model_config
)
allowed_counts
=
mm_processor
.
info
.
get_allowed_mm_limits
()
allowed_count
=
allowed_counts
.
get
(
input_modality
,
0
)
else
:
mm_config
=
model_config
.
multimodal_config
if
mm_config
is
None
:
msg
=
"This model does not support multi-modal inputs"
raise
ValueError
(
msg
)
allowed_count
=
mm_config
.
get_limit_per_prompt
(
input_modality
)
current_count
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
current_count
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
if
current_count
>
allowed_count
:
if
current_count
>
allowed_count
:
raise
ValueError
(
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it."
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
...
@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[
...
@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
]
=
{
"text"
:
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
"image_url"
:
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
None
),
"image_embeds"
:
"image_embeds"
:
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
{}
),
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
None
),
"audio_url"
:
"audio_url"
:
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
None
),
"input_audio"
:
"input_audio"
:
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
{}
),
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
None
),
"refusal"
:
"refusal"
:
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
""
),
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
"video_url"
:
"video_url"
:
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
None
),
}
}
...
@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part(
...
@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part(
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is
empty
, log a warning and skip
# content is
None
, log a warning and skip
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
not
content
:
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
logger
.
warning
(
logger
.
warning
(
"Skipping multimodal part (type: '%s') "
"Skipping multimodal part
'%s'
(type: '%s') "
"with empty / unparsable content."
,
part_type
)
"with empty / unparsable content."
,
part
,
part_type
)
return
None
return
None
if
part_type
in
(
"text"
,
"refusal"
):
if
part_type
in
(
"text"
,
"refusal"
):
...
@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template(
...
@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template(
**
kwargs
,
**
kwargs
,
)
)
return
tokenizer
.
apply_chat_template
(
try
:
messages
=
messages
,
return
tokenizer
.
apply_chat_template
(
tools
=
tools
,
messages
=
messages
,
**
kwargs
,
tools
=
tools
,
)
**
kwargs
,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except
AssertionError
as
e
:
raise
ValueError
from
e
vllm/entrypoints/cli/benchmark/base.py
View file @
31330101
...
@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand):
...
@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand):
parser
=
subparsers
.
add_parser
(
parser
=
subparsers
.
add_parser
(
self
.
name
,
self
.
name
,
help
=
self
.
help
,
help
=
self
.
help
,
description
=
self
.
help
,
usage
=
f
"vllm bench
{
self
.
name
}
[options]"
)
usage
=
f
"vllm bench
{
self
.
name
}
[options]"
)
self
.
add_cli_args
(
parser
)
self
.
add_cli_args
(
parser
)
return
parser
return
parser
vllm/entrypoints/cli/benchmark/main.py
View file @
31330101
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
...
@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser
=
subparsers
.
add_parser
(
bench_parser
=
subparsers
.
add_parser
(
"bench"
,
"bench"
,
help
=
"vLLM bench subcommand."
,
help
=
"vLLM bench subcommand."
,
description
=
"vLLM bench subcommand."
,
usage
=
"vllm bench <bench_type> [options]"
)
usage
=
"vllm bench <bench_type> [options]"
)
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
bench_subparsers
=
bench_parser
.
add_subparsers
(
required
=
True
,
dest
=
"bench_type"
)
dest
=
"bench_type"
)
...
...
vllm/entrypoints/cli/openai.py
View file @
31330101
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
...
@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
chat_parser
=
subparsers
.
add_parser
(
chat_parser
=
subparsers
.
add_parser
(
"chat"
,
"chat"
,
help
=
"Generate chat completions via the running API server"
,
help
=
"Generate chat completions via the running API server."
,
description
=
"Generate chat completions via the running API server."
,
usage
=
"vllm chat [options]"
)
usage
=
"vllm chat [options]"
)
_add_query_options
(
chat_parser
)
_add_query_options
(
chat_parser
)
chat_parser
.
add_argument
(
chat_parser
.
add_argument
(
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
...
@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser
=
subparsers
.
add_parser
(
complete_parser
=
subparsers
.
add_parser
(
"complete"
,
"complete"
,
help
=
(
"Generate text completions based on the given prompt "
help
=
(
"Generate text completions based on the given prompt "
"via the running API server"
),
"via the running API server."
),
description
=
(
"Generate text completions based on the given prompt "
"via the running API server."
),
usage
=
"vllm complete [options]"
)
usage
=
"vllm complete [options]"
)
_add_query_options
(
complete_parser
)
_add_query_options
(
complete_parser
)
return
complete_parser
return
complete_parser
...
...
vllm/entrypoints/cli/serve.py
View file @
31330101
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
...
@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
serve_parser
=
subparsers
.
add_parser
(
serve_parser
=
subparsers
.
add_parser
(
"serve"
,
"serve"
,
help
=
"Start the vLLM OpenAI Compatible API server"
,
help
=
"Start the vLLM OpenAI Compatible API server."
,
description
=
"Start the vLLM OpenAI Compatible API server."
,
usage
=
"vllm serve [model_tag] [options]"
)
usage
=
"vllm serve [model_tag] [options]"
)
serve_parser
.
add_argument
(
"model_tag"
,
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
type
=
str
,
...
...
vllm/entrypoints/llm.py
View file @
31330101
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import
cloudpickle
import
cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
@@ -117,6 +117,9 @@ class LLM:
...
@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files.
If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
HuggingFace config.
...
@@ -177,6 +180,7 @@ class LLM:
...
@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
# After positional args are removed, move this right below `model`
...
@@ -232,6 +236,7 @@ class LLM:
...
@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
disable_async_output_proc
=
disable_async_output_proc
,
hf_token
=
hf_token
,
hf_overrides
=
hf_overrides
,
hf_overrides
=
hf_overrides
,
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
override_pooler_config
=
override_pooler_config
,
override_pooler_config
=
override_pooler_config
,
...
@@ -531,6 +536,16 @@ class LLM:
...
@@ -531,6 +536,16 @@ class LLM:
tokenizer
.
eos_token_id
,
tokenizer
.
eos_token_id
,
length_penalty
)
length_penalty
)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if
any
(
"multi_modal_data"
in
prompt
and
prompt
[
"multi_modal_data"
]
is
not
None
for
prompt
in
prompts
):
logger
.
warning
(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!"
)
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# following the huggingface transformers implementation
...
@@ -906,6 +921,11 @@ class LLM:
...
@@ -906,6 +921,11 @@ class LLM:
if
pooling_params
is
None
:
if
pooling_params
is
None
:
# Use default pooling params.
# Use default pooling params.
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
self
.
llm_engine
.
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
self
.
llm_engine
.
model_config
)
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
prompts
=
parsed_prompts
,
...
@@ -924,6 +944,8 @@ class LLM:
...
@@ -924,6 +944,8 @@ class LLM:
/
,
/
,
*
,
*
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
list
[
EmbeddingRequestOutput
]:
)
->
list
[
EmbeddingRequestOutput
]:
...
@@ -938,6 +960,8 @@ class LLM:
...
@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
prompt_adapter_request: Prompt Adapter request to use for
...
@@ -953,6 +977,7 @@ class LLM:
...
@@ -953,6 +977,7 @@ class LLM:
items
=
self
.
encode
(
prompts
,
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
31330101
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
self
.
guided_json
=
json_schema
.
json_schema
if
self
.
guided_decoding_backend
is
None
:
self
.
guided_decoding_backend
=
"xgrammar"
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
...
@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
return
data
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
31330101
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
...
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
31330101
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return
error_check_ret
return
error_check_ret
encoding_format
=
request
.
encoding_format
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
"greater than max_model_len."
" Please, select a smaller truncation size."
)
" Please, select a smaller truncation size."
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
31330101
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
...
@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class
PythonicToolParser
(
ToolParser
):
class
PythonicToolParser
(
ToolParser
):
"""
"""
Tool call parser for models that produce tool calls in a pythonic style,
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2
and Llama 4
models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
"""
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
31330101
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
...
@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# partial_json_parser doesn't support extra data and
# JSONDeco
r
der.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
tuple
[
Any
,
int
]:
try
:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
...
...
vllm/envs.py
View file @
31330101
...
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
...
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
...
@@ -115,6 +116,7 @@ if TYPE_CHECKING:
...
@@ -115,6 +116,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when using the flash-attention backend.
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION"
:
"VLLM_FLASH_ATTN_VERSION"
:
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_VERSION"
,
None
)),
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_VERSION"
,
None
)),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
,
"-1"
)),
# Internal flag to enable Dynamo fullgraph capture
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
...
@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
# Use model_redirect to redirect the model name to a local folder.
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH"
:
"VLLM_MODEL_REDIRECT_PATH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
lambda
:
os
.
environ
.
get
(
"VLLM_MODEL_REDIRECT_PATH"
,
None
),
...
@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
"VLLM_USE_DEEP_GEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
# Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_XGRAMMAR_CACHE_MB"
,
"512"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/lora/layers.py
View file @
31330101
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
and
len
(
packed_modules_list
)
==
3
)
#TODO: Implement this
class
QKVCrossParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
pass
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
...
...
vllm/lora/models.py
View file @
31330101
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
model
.
lora_manager
=
self
self
.
adapter_type
=
'LoR
a
'
self
.
adapter_type
=
'LoR
A
'
@
property
@
property
def
capacity
(
self
)
->
int
:
def
capacity
(
self
)
->
int
:
...
...
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
View file @
31330101
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
...
@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
# active_lora_ids, num_tokens_per_lora
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
lora_ids
,
num_tokens_per_lora
=
torch
.
unique
(
token_lora_mapping
,
sorted
=
Fals
e
,
sorted
=
Tru
e
,
return_counts
=
True
)
return_counts
=
True
)
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
self
.
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
non_blocking
=
True
)
non_blocking
=
True
)
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
31330101
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
...
@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
logger
.
warning
(
"%s Falling back to use %s instead."
,
message
,
fallback
)
guided_params
.
backend
=
fallback
guided_params
.
backend
=
fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if
guided_params
.
backend
==
"auto"
:
guided_params
.
backend
=
"xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
backend_name
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
if
guided_params
.
grammar
is
not
None
:
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
...
@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
xgr_installed
)
xgr_installed
)
# xgrammar doesn't support regex, fallback to outlines
if
guided_params
.
regex
is
not
None
:
fallback_or_error
(
guided_params
,
"xgrammar does not support regex guided decoding."
,
"outlines"
)
# xgrammar doesn't support some JSON schema features
# xgrammar doesn't support some JSON schema features
el
if
(
guided_params
.
json
is
not
None
if
(
guided_params
.
json
is
not
None
and
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
fallback_or_error
(
fallback_or_error
(
guided_params
,
guided_params
,
"xgrammar does not support advanced JSON schema features like "
"xgrammar does not support advanced JSON schema features like "
...
...
vllm/model_executor/guided_decoding/utils.py
View file @
31330101
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
...
@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if
"pattern"
in
obj
:
if
"pattern"
in
obj
:
return
True
return
True
# Check for enum restrictions
if
"enum"
in
obj
:
return
True
# Check for numeric ranges
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
key
in
obj
for
key
in
[
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
31330101
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import
torch
import
torch
import
vllm.envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
try
:
try
:
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
...
@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab
=
config_data
.
encoded_vocab
,
encoded_vocab
=
config_data
.
encoded_vocab
,
metadata
=
config_data
.
metadata
,
metadata
=
config_data
.
metadata
,
)
)
cache_size
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
tokenizer_info
,
max_threads
=
config
.
max_threads
,
cache_enabled
=
True
,
cache_limit_bytes
=
cache_size
,
)
return
cls
.
_cache
[
cache_key
]
return
cls
.
_cache
[
cache_key
]
...
@@ -146,6 +152,7 @@ class GrammarConfig:
...
@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
json_object
:
bool
|
None
=
None
any_whitespace
:
bool
=
True
any_whitespace
:
bool
=
True
regex_str
:
str
|
None
=
None
max_threads
:
int
=
8
max_threads
:
int
=
8
@
classmethod
@
classmethod
...
@@ -249,6 +256,13 @@ class GrammarConfig:
...
@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads
=
max_threads
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
tokenizer_data
=
tokenizer_data
,
)
)
elif
guided_params
.
regex
:
return
cls
(
regex_str
=
guided_params
.
regex
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
...
@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self
.
ctx
=
compiler
\
self
.
ctx
=
compiler
\
.
compile_json_schema
(
'{"type": "object"}'
,
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
any_whitespace
)
any_whitespace
=
any_whitespace
)
elif
self
.
config
.
regex_str
:
self
.
ctx
=
compiler
.
compile_regex
(
self
.
config
.
regex_str
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid configuration for xgrammar logits processor"
)
"Invalid configuration for xgrammar logits processor"
)
...
...
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json
0 → 100644
View file @
31330101
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
31330101
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
...
@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq(
...
@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq(
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel
(
def
fused_moe_kernel
(
# Pointers to matrices
# Pointers to matrices
a_ptr
,
a_ptr
,
b_ptr
,
b_ptr
,
c_ptr
,
c_ptr
,
a_scale_ptr
,
a_scale_ptr
,
b_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
# Matrix dimensions
N
,
N
,
K
,
K
,
EM
,
EM
,
num_valid_tokens
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
# (A has M rows).
stride_am
,
stride_am
,
stride_ak
,
stride_ak
,
stride_be
,
stride_be
,
stride_bk
,
stride_bk
,
stride_bn
,
stride_bn
,
stride_cm
,
stride_cm
,
stride_cn
,
stride_cn
,
stride_asm
,
stride_asm
,
stride_ask
,
stride_ask
,
stride_bse
,
stride_bse
,
stride_bsk
,
stride_bsk
,
stride_bsn
,
stride_bsn
,
# Block size for block-wise quantization
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
):
"""
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
token and expert matrices.
...
@@ -605,11 +610,22 @@ def fused_moe_kernel(
...
@@ -605,11 +610,22 @@ def fused_moe_kernel(
b_scale
=
tl
.
load
(
b_scale_ptrs
)
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
or
use_int8_w8a8
:
if
use_fp8_w8a8
or
use_int8_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
offs_bsn
*
stride_bsn
)
# channel-wise
elif
per_channel_quant
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
token_mask
,
other
=
0.0
)[:,
None
]
# tensor-wise
else
:
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
@@ -645,7 +661,11 @@ def fused_moe_kernel(
...
@@ -645,7 +661,11 @@ def fused_moe_kernel(
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
None
]
*
b_scale
[
None
,
:]
else
:
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
use_fp8_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
# Advance the ptrs to the next K block.
...
@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
num_tokens
=
M
*
top_k
...
@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
#BLOCK_SIZE_K=BLOCK_SIZE_K,
per_channel_quant
=
per_channel_quant
,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**
config
,
**
config
,
)
)
...
@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
per_channel_quant
,
global_num_experts
,
expert_map
,
a2_scale
,
block_shape
,
use_nn_moe
)
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake(
...
@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1310,6 +1314,7 @@ direct_register_custom_op(
...
@@ -1310,6 +1314,7 @@ direct_register_custom_op(
op_func
=
inplace_fused_experts
,
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
fake_impl
=
inplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
...
@@ -1325,6 +1330,7 @@ def outplace_fused_experts(
...
@@ -1325,6 +1330,7 @@ def outplace_fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1337,7 +1343,8 @@ def outplace_fused_experts(
...
@@ -1337,7 +1343,8 @@ def outplace_fused_experts(
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
)
...
@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake(
...
@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1372,6 +1380,7 @@ direct_register_custom_op(
...
@@ -1372,6 +1380,7 @@ direct_register_custom_op(
op_func
=
outplace_fused_experts
,
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
fake_impl
=
outplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
...
@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
def
moe_kernel_prepare_input
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
else
:
qcurr_hidden_states
=
curr_hidden_states
a1q_scale
=
a1_scale
if
use_int8_w8a8
:
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
m
=
curr_hidden_states
.
shape
[
0
]
...
@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps"
:
4
"num_warps"
:
4
}
}
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
B
=
w1
,
A_scale
=
a1_scale
,
B_scale
=
w1_scale
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_int4_w4a16
:
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
...
@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
a1
q
_scale
,
q
a1_scale
,
w1_scale
,
w1_scale
,
w1_zp
,
w1_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1
.
view
(
-
1
,
N
))
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
qa2_scale
=
moe_kernel_prepare_input
(
A
=
intermediate_cache2
,
if
use_fp8_w8a8
:
B
=
w2
,
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
A_scale
=
a2_scale
,
intermediate_cache2
,
a2_scale
,
block_shape
)
B_scale
=
w2_scale
,
else
:
use_fp8_w8a8
=
use_fp8_w8a8
,
qintermediate_cache2
=
intermediate_cache2
use_int8_w8a8
=
use_int8_w8a8
,
a2q_scale
=
a2_scale
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_int8_w8a8
:
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
m
=
curr_hidden_states
.
shape
[
0
]
...
@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
a2
q
_scale
,
q
a2_scale
,
w2_scale
,
w2_scale
,
w2_zp
,
w2_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1739,6 +1812,7 @@ def fused_moe(
...
@@ -1739,6 +1812,7 @@ def fused_moe(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1772,6 +1846,8 @@ def fused_moe(
...
@@ -1772,6 +1846,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
activation to compute the inner products for w1 and w2.
Defaults to False.
Defaults to False.
...
@@ -1821,6 +1897,7 @@ def fused_moe(
...
@@ -1821,6 +1897,7 @@ def fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment