Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1fbdf957
Commit
1fbdf957
authored
Feb 26, 2025
by
王敏
Browse files
[feat]合入ds3 MTP功能
parent
537e2d9c
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
950 additions
and
121 deletions
+950
-121
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+35
-16
tests/models/registry.py
tests/models/registry.py
+35
-17
tests/spec_decode/e2e/test_mtp_correctness.py
tests/spec_decode/e2e/test_mtp_correctness.py
+318
-0
vllm/config.py
vllm/config.py
+105
-55
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+5
-2
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+326
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+16
-8
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+27
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+11
-2
vllm/sequence.py
vllm/sequence.py
+2
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+15
-5
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+30
-8
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+14
-2
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+4
-1
vllm/worker/worker.py
vllm/worker/worker.py
+3
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+4
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
1fbdf957
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# adding a new command to an existing step. See different options here for examples.
# adding a new command to an existing step. See different options here for examples.
# This script will be feed into Jinja template in `test-template-aws.j2` at
# This script will be feed into Jinja template in `test-template-aws.j2` at
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# to generate the final pipeline yaml file.
# to generate the final pipeline yaml file.
# Documentation
# Documentation
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
# in this case, commands must be specified. the first command runs on first host, the second
# in this case, commands must be specified. the first command runs on first host, the second
# command runs on the second host.
# command runs on the second host.
# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
...
@@ -24,8 +24,8 @@
...
@@ -24,8 +24,8 @@
# When adding a test
# When adding a test
# - If the test belong to an existing group, add it there
# - If the test belong to an existing group, add it there
# - If the test is short, add to any existing step
# - If the test is short, add to any existing step
# - If the test takes more than 10min, then it is okay to create a new step.
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.
# Note that all steps execute in parallel.
steps
:
steps
:
##### fast check tests #####
##### fast check tests #####
...
@@ -107,13 +107,17 @@ steps:
...
@@ -107,13 +107,17 @@ steps:
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/entrypoints/llm
-
tests/entrypoints/openai
-
tests/entrypoints/test_chat_utils
-
tests/entrypoints/offline_mode
commands
:
commands
:
-
pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_lazy_outlines.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_lazy_outlines.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate_multiple_loras.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_generate_multiple_loras.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_guided_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/llm/test_guided_generate.py
# it needs a clean process
-
pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
-
pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
--ignore=entrypoints/openai/correctness/
-
pytest -v -s entrypoints/test_chat_utils.py
-
pytest -v -s entrypoints/test_chat_utils.py
-
pytest -v -s entrypoints/offline_mode
# Needs to avoid interference with other tests
-
pytest -v -s entrypoints/offline_mode
# Needs to avoid interference with other tests
...
@@ -124,11 +128,12 @@ steps:
...
@@ -124,11 +128,12 @@ steps:
source_file_dependencies
:
source_file_dependencies
:
-
vllm/distributed/
-
vllm/distributed/
-
vllm/core/
-
vllm/core/
-
tests/distributed
-
tests/distributed/test_utils
-
tests/distributed/test_pynccl
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/compile
-
tests/compile
/test_basic_correctness
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/r
ay_placement
.py
-
examples/offline_inference/r
lhf_colocate
.py
commands
:
commands
:
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s compile/test_basic_correctness.py
...
@@ -137,17 +142,17 @@ steps:
...
@@ -137,17 +142,17 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
# when we have multiple distributed example tests
-
python3 ../examples/offline_inference/rlhf.py
-
python3 ../examples/offline_inference/rlhf.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
ay_placement
.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
lhf_colocate
.py
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
num_gpus
:
2
num_gpus
:
2
fast_check
:
true
fast_check
:
true
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/metrics
-
tests/metrics
-
tests/tracing
-
tests/tracing
commands
:
commands
:
-
pytest -v -s metrics
-
pytest -v -s metrics
-
"
pip
install
\
-
"
pip
install
\
'opentelemetry-sdk>=1.26.0,<1.27.0'
\
'opentelemetry-sdk>=1.26.0,<1.27.0'
\
'opentelemetry-api>=1.26.0,<1.27.0'
\
'opentelemetry-api>=1.26.0,<1.27.0'
\
...
@@ -174,6 +179,9 @@ steps:
...
@@ -174,6 +179,9 @@ steps:
-
vllm/
-
vllm/
-
tests/engine
-
tests/engine
-
tests/tokenization
-
tests/tokenization
-
tests/test_sequence
-
tests/test_config
-
tests/test_logger
commands
:
commands
:
-
pytest -v -s engine test_sequence.py test_config.py test_logger.py
-
pytest -v -s engine test_sequence.py test_config.py test_logger.py
# OOM in the CI unless we run this separately
# OOM in the CI unless we run this separately
...
@@ -195,6 +203,9 @@ steps:
...
@@ -195,6 +203,9 @@ steps:
# TODO: accuracy does not match, whether setting
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
-
VLLM_USE_V1=1 pytest -v -s v1/e2e
-
VLLM_USE_V1=1 pytest -v -s v1/e2e
# Integration test for streaming correctness (requires special branch).
-
pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
-
pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
-
label
:
Examples Test
# 25min
-
label
:
Examples Test
# 25min
working_dir
:
"
/vllm-workspace/examples"
working_dir
:
"
/vllm-workspace/examples"
...
@@ -243,7 +254,7 @@ steps:
...
@@ -243,7 +254,7 @@ steps:
-
vllm/model_executor/guided_decoding
-
vllm/model_executor/guided_decoding
-
tests/test_logits_processor
-
tests/test_logits_processor
-
tests/model_executor/test_guided_processors
-
tests/model_executor/test_guided_processors
commands
:
commands
:
-
pytest -v -s test_logits_processor.py
-
pytest -v -s test_logits_processor.py
-
pytest -v -s model_executor/test_guided_processors.py
-
pytest -v -s model_executor/test_guided_processors.py
...
@@ -254,7 +265,7 @@ steps:
...
@@ -254,7 +265,7 @@ steps:
-
vllm/model_executor/models/eagle.py
-
vllm/model_executor/models/eagle.py
commands
:
commands
:
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
--ignore=spec_decode/e2e/test_mtp_correctness.py
-
pytest -v -s spec_decode/e2e/test_eagle_correctness.py
-
pytest -v -s spec_decode/e2e/test_eagle_correctness.py
-
label
:
LoRA Test %N
# 15min each
-
label
:
LoRA Test %N
# 15min each
...
@@ -328,6 +339,14 @@ steps:
...
@@ -328,6 +339,14 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
bash ./run-tests.sh -c configs/models-small.txt -t
1
-
bash ./run-tests.sh -c configs/models-small.txt -t
1
-
label
:
OpenAI API correctness
source_file_dependencies
:
-
csrc/
-
vllm/entrypoints/openai/
-
vllm/model_executor/models/whisper.py
commands
:
# LMEval+Transcription WER check
-
pytest -s entrypoints/openai/correctness/
-
label
:
Encoder Decoder tests
# 5min
-
label
:
Encoder Decoder tests
# 5min
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
...
@@ -561,7 +580,7 @@ steps:
...
@@ -561,7 +580,7 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
# This test runs llama 13B, so it is required to run on 4 GPUs.
-
pytest -v -s -x lora/test_long_context.py
-
pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
# requires multi-GPU testing for validation.
-
pytest -v -s -x lora/test_chatglm3_tp.py
-
pytest -v -s -x lora/test_chatglm3_tp.py
-
pytest -v -s -x lora/test_llama_tp.py
-
pytest -v -s -x lora/test_llama_tp.py
...
@@ -586,7 +605,7 @@ steps:
...
@@ -586,7 +605,7 @@ steps:
-
vllm/
-
vllm/
-
tests/weight_loading
-
tests/weight_loading
commands
:
commands
:
-
bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
-
bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
##### multi gpus test #####
##### multi gpus test #####
...
@@ -598,7 +617,7 @@ steps:
...
@@ -598,7 +617,7 @@ steps:
num_gpus
:
4
num_gpus
:
4
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
commands
:
commands
:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
# see https://github.com/vllm-project/vllm/pull/5689 for details
-
pytest -v -s distributed/test_custom_all_reduce.py
-
pytest -v -s distributed/test_custom_all_reduce.py
...
...
tests/models/registry.py
View file @
1fbdf957
...
@@ -102,8 +102,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -102,8 +102,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"BambaForCausalLM"
:
_HfExamplesInfo
(
"ibm-ai-platform/Bamba-9B"
),
"BloomForCausalLM"
:
_HfExamplesInfo
(
"bigscience/bloomz-1b1"
),
"BloomForCausalLM"
:
_HfExamplesInfo
(
"bigscience/bloomz-1b1"
),
# ChatGLMModel supports multimodal
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/chatglm3-6b"
,
trust_remote_code
=
True
),
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Cohere2ForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r7b-12-2024"
,
# noqa: E501
"Cohere2ForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r7b-12-2024"
,
# noqa: E501
...
@@ -137,11 +139,14 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -137,11 +139,14 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"InternLM3ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm3-8b-instruct"
,
"InternLM3ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm3-8b-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"JAISLMHeadModel"
:
_HfExamplesInfo
(
"inceptionai/jais-13b-chat"
),
"JAISLMHeadModel"
:
_HfExamplesInfo
(
"inceptionai/jais-13b-chat"
),
"JambaForCausalLM"
:
_HfExamplesInfo
(
"ai21labs/AI21-Jamba-1.5-Mini"
),
"JambaForCausalLM"
:
_HfExamplesInfo
(
"ai21labs/AI21-Jamba-1.5-Mini"
,
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Meta-Llama-3-8B"
),
extras
=
{
"tiny"
:
"ai21labs/Jamba-tiny-dev"
}),
# noqa: E501
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-1B-Instruct"
),
"LLaMAForCausalLM"
:
_HfExamplesInfo
(
"decapoda-research/llama-7b-hf"
,
"LLaMAForCausalLM"
:
_HfExamplesInfo
(
"decapoda-research/llama-7b-hf"
,
is_available_online
=
False
),
is_available_online
=
False
),
"MambaForCausalLM"
:
_HfExamplesInfo
(
"state-spaces/mamba-130m-hf"
),
"MambaForCausalLM"
:
_HfExamplesInfo
(
"state-spaces/mamba-130m-hf"
),
"Mamba2ForCausalLM"
:
_HfExamplesInfo
(
"mistralai/Mamba-Codestral-7B-v0.1"
,
is_available_online
=
False
),
"FalconMambaForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-mamba-7b-instruct"
),
# noqa: E501
"FalconMambaForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-mamba-7b-instruct"
),
# noqa: E501
"MiniCPMForCausalLM"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-2B-sft-bf16"
,
"MiniCPMForCausalLM"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-2B-sft-bf16"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
@@ -166,7 +171,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -166,7 +171,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
# QWenLMHeadModel supports multimodal
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
trust_remote_code
=
True
),
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen2-7B-Instruct"
),
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen2-7B-Instruct"
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
,
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
,
...
@@ -213,6 +219,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -213,6 +219,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"TIGER-Lab/VLM2Vec-Full"
,
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"TIGER-Lab/VLM2Vec-Full"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"MrLight/dse-qwen2-2b-mrl-v1"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"MrLight/dse-qwen2-2b-mrl-v1"
),
# noqa: E501
# The model on Huggingface is currently being updated,
# hence I temporarily mark it as not available online
"PrithviGeoSpatialMAE"
:
_HfExamplesInfo
(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
,
# noqa: E501
is_available_online
=
False
),
}
}
_CROSS_ENCODER_EXAMPLE_MODELS
=
{
_CROSS_ENCODER_EXAMPLE_MODELS
=
{
...
@@ -227,18 +237,19 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -227,18 +237,19 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
extras
=
{
"text_only"
:
"THUDM/chatglm3-6b"
},
trust_remote_code
=
True
),
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"chatglm2-6b"
,
is_available_online
=
False
),
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
,
# noqa: E501
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
),
"GLM4VForCausalLM"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
trust_remote_code
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
,
extras
=
{
"2b"
:
"h2oai/h2ovl-mississippi-2b"
}),
# noqa: E501
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
extras
=
{
"2B"
:
"OpenGVLab/InternVL2-2B"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
),
# noqa: E501
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
,
# noqa: E501
{
"tiny"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
}),
# noqa: E501
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
extras
=
{
"mistral"
:
"mistral-community/pixtral-12b"
}),
# noqa: E501
extras
=
{
"mistral"
:
"mistral-community/pixtral-12b"
}),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
...
@@ -248,25 +259,29 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -248,25 +259,29 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}),
# noqa: E501
"MiniCPMO"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-o-2_6"
,
"MiniCPMO"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-o-2_6"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-Llama3-V-2_5"
,
extras
=
{
"2.6"
:
"openbmb/MiniCPM-V-2_6"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
extras
=
{
"olmo"
:
"allenai/Molmo-7B-O-0924"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"PaliGemmaForConditionalGeneration"
:
_HfExamplesInfo
(
"google/paligemma-3b-pt-224"
),
# noqa: E501
"PaliGemmaForConditionalGeneration"
:
_HfExamplesInfo
(
"google/paligemma-3b-mix-224"
,
# noqa: E501
extras
=
{
"v2"
:
"google/paligemma2-3b-ft-docci-448"
}),
# noqa: E501
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-vision-128k-instruct"
,
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-vision-128k-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"PixtralForConditionalGeneration"
:
_HfExamplesInfo
(
"mistralai/Pixtral-12B-2409"
,
# noqa: E501
"PixtralForConditionalGeneration"
:
_HfExamplesInfo
(
"mistralai/Pixtral-12B-2409"
,
# noqa: E501
tokenizer_mode
=
"mistral"
),
tokenizer_mode
=
"mistral"
),
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-VL-Chat"
,
"QwenVLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen-VL"
,
extras
=
{
"text_only"
:
"Qwen/Qwen-7B-Chat"
},
# noqa: E501
extras
=
{
"chat"
:
"Qwen/Qwen-VL-Chat"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
,
hf_overrides
=
{
"architectures"
:
[
"QwenVLForConditionalGeneration"
]}),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2_5_VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
# noqa: E501
"Qwen2_5_VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
# noqa: E501
min_transformers_version
=
"4.49"
),
# noqa: E501
min_transformers_version
=
"4.49"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_
3
"
,
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_
5-llama-3_2-1b
"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
# [Encoder-decoder]
# [Encoder-decoder]
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
...
@@ -280,6 +295,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -280,6 +295,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model
=
"abhigoyal/vllm-medusa-llama-68m-random"
),
# noqa: E501
speculative_model
=
"abhigoyal/vllm-medusa-llama-68m-random"
),
# noqa: E501
"MLPSpeculatorPreTrainedModel"
:
_HfExamplesInfo
(
"JackFram/llama-160m"
,
"MLPSpeculatorPreTrainedModel"
:
_HfExamplesInfo
(
"JackFram/llama-160m"
,
speculative_model
=
"ibm-ai-platform/llama-160m-accelerator"
),
# noqa: E501
speculative_model
=
"ibm-ai-platform/llama-160m-accelerator"
),
# noqa: E501
"DeepSeekMTPModel"
:
_HfExamplesInfo
(
"luccafong/deepseek_mtp_main_random"
,
speculative_model
=
"luccafong/deepseek_mtp_draft_random"
,
# noqa: E501
trust_remote_code
=
True
),
}
}
_FALLBACK_MODEL
=
{
_FALLBACK_MODEL
=
{
...
...
tests/spec_decode/e2e/test_mtp_correctness.py
0 → 100644
View file @
1fbdf957
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, mtp would not break the
correctess for the target model outputs.
"""
import
pytest
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"luccafong/deepseek_mtp_main_random"
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
1
# precision
PRECISION
=
"bfloat16"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_mtp_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"enforce_eager"
:
False
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"num_speculative_tokens"
:
k
,
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
vllm/config.py
View file @
1fbdf957
...
@@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
...
@@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
"score"
,
"reward"
]
"score"
,
"reward"
,
"transcription"
]
_ResolvedTask
=
Literal
[
"generate"
,
"embed"
,
"classify"
,
"score"
,
"reward"
,
_ResolvedTask
=
Literal
[
"generate"
,
"embed"
,
"classify"
,
"score"
,
"reward"
,
"draft"
]
"draft"
,
"transcription"
]
RunnerType
=
Literal
[
"generate"
,
"pooling"
,
"draft"
]
RunnerType
=
Literal
[
"generate"
,
"pooling"
,
"draft"
,
"transcription"
]
_RUNNER_TASKS
:
Dict
[
RunnerType
,
List
[
_ResolvedTask
]]
=
{
_RUNNER_TASKS
:
Dict
[
RunnerType
,
List
[
_ResolvedTask
]]
=
{
"generate"
:
[
"generate"
],
"generate"
:
[
"generate"
],
"pooling"
:
[
"embed"
,
"classify"
,
"score"
,
"reward"
],
"pooling"
:
[
"embed"
,
"classify"
,
"score"
,
"reward"
],
"draft"
:
[
"draft"
],
"draft"
:
[
"draft"
],
"transcription"
:
[
"transcription"
],
}
}
_TASK_RUNNER
:
Dict
[
_ResolvedTask
,
RunnerType
]
=
{
_TASK_RUNNER
:
Dict
[
_ResolvedTask
,
RunnerType
]
=
{
...
@@ -102,8 +103,9 @@ class ModelConfig:
...
@@ -102,8 +103,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`.
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
allowed_local_media_path: Allowing API requests to read local images or
...
@@ -468,10 +470,10 @@ class ModelConfig:
...
@@ -468,10 +470,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]:
raise
ValueError
(
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto', 'slow'
or
'mistral'."
)
"either 'auto', 'slow'
,
'mistral'
or 'custom'
."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_get_preferred_task
(
def
_get_preferred_task
(
...
@@ -484,6 +486,8 @@ class ModelConfig:
...
@@ -484,6 +486,8 @@ class ModelConfig:
return
"embed"
return
"embed"
if
ModelRegistry
.
is_cross_encoder_model
(
architectures
):
if
ModelRegistry
.
is_cross_encoder_model
(
architectures
):
return
"score"
return
"score"
if
ModelRegistry
.
is_transcription_model
(
architectures
):
return
"transcription"
suffix_to_preferred_task
:
List
[
Tuple
[
str
,
_ResolvedTask
]]
=
[
suffix_to_preferred_task
:
List
[
Tuple
[
str
,
_ResolvedTask
]]
=
[
# Other models follow this pattern
# Other models follow this pattern
...
@@ -516,6 +520,8 @@ class ModelConfig:
...
@@ -516,6 +520,8 @@ class ModelConfig:
runner_support
:
Dict
[
RunnerType
,
bool
]
=
{
runner_support
:
Dict
[
RunnerType
,
bool
]
=
{
# NOTE: Listed from highest to lowest priority,
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
# in case the model supports multiple of them
"transcription"
:
ModelRegistry
.
is_transcription_model
(
architectures
),
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"pooling"
:
ModelRegistry
.
is_pooling_model
(
architectures
),
"pooling"
:
ModelRegistry
.
is_pooling_model
(
architectures
),
}
}
...
@@ -757,7 +763,7 @@ class ModelConfig:
...
@@ -757,7 +763,7 @@ class ModelConfig:
def
is_deepseek_mla
(
self
)
->
bool
:
def
is_deepseek_mla
(
self
)
->
bool
:
return
(
hasattr
(
self
.
hf_text_config
,
"model_type"
))
\
return
(
hasattr
(
self
.
hf_text_config
,
"model_type"
))
\
and
(
self
.
hf_text_config
.
model_type
in
\
and
(
self
.
hf_text_config
.
model_type
in
\
(
'deepseek_v2'
,
'deepseek_v3'
))
\
(
'deepseek_v2'
,
'deepseek_v3'
,
'deepseek_mtp'
))
\
and
(
self
.
hf_text_config
.
kv_lora_rank
is
not
None
)
and
(
self
.
hf_text_config
.
kv_lora_rank
is
not
None
)
def
get_head_size
(
self
)
->
int
:
def
get_head_size
(
self
)
->
int
:
...
@@ -850,8 +856,12 @@ class ModelConfig:
...
@@ -850,8 +856,12 @@ class ModelConfig:
def
get_layers_start_end_indices
(
def
get_layers_start_end_indices
(
self
,
parallel_config
:
"ParallelConfig"
)
->
Tuple
[
int
,
int
]:
self
,
parallel_config
:
"ParallelConfig"
)
->
Tuple
[
int
,
int
]:
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.distributed.utils
import
get_pp_indices
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
if
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
:
"num_hidden_layers"
,
0
)
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
else
:
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_hidden_layers"
,
0
)
pp_rank
=
parallel_config
.
rank
//
parallel_config
.
tensor_parallel_size
pp_rank
=
parallel_config
.
rank
//
parallel_config
.
tensor_parallel_size
pp_size
=
parallel_config
.
pipeline_parallel_size
pp_size
=
parallel_config
.
pipeline_parallel_size
start
,
end
=
get_pp_indices
(
total_num_hidden_layers
,
pp_rank
,
pp_size
)
start
,
end
=
get_pp_indices
(
total_num_hidden_layers
,
pp_rank
,
pp_size
)
...
@@ -986,37 +996,7 @@ class ModelConfig:
...
@@ -986,37 +996,7 @@ class ModelConfig:
@
property
@
property
def
use_mla
(
self
)
->
bool
:
def
use_mla
(
self
)
->
bool
:
if
not
self
.
is_deepseek_mla
or
envs
.
VLLM_MLA_DISABLE
:
return
self
.
is_deepseek_mla
and
not
envs
.
VLLM_MLA_DISABLE
return
False
if
self
.
quantization
is
not
None
and
self
.
quantization
not
in
[
\
"fp8"
,
"compressed-tensors"
]:
logger
.
warning
(
"MLA is not supported with %s quantization. "
"Disabling MLA."
,
self
.
quantization
)
return
False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if
self
.
quantization
==
"compressed-tensors"
:
quant_config
=
self
.
_parse_quant_hf_config
()
for
group_name
,
cfg
in
quant_config
.
get
(
"config_groups"
,
{
""
:
{}
}).
items
():
act_cfg
=
cfg
.
get
(
"input_activations"
,
{})
act_type
=
None
if
act_cfg
is
None
else
act_cfg
.
get
(
"type"
,
""
)
w_cfg
=
cfg
.
get
(
"weights"
,
{})
w_type
=
None
if
w_cfg
is
None
else
w_cfg
.
get
(
"type"
,
""
)
if
act_type
!=
"fp8"
or
w_type
!=
"fp8"
:
logger
.
warning
(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.
\n
"
"Full config: %s"
,
group_name
,
act_type
,
w_type
,
quant_config
)
return
False
return
True
@
property
@
property
def
supported_runner_types
(
self
)
->
Set
[
RunnerType
]:
def
supported_runner_types
(
self
)
->
Set
[
RunnerType
]:
...
@@ -1404,6 +1384,9 @@ class ParallelConfig:
...
@@ -1404,6 +1384,9 @@ class ParallelConfig:
logger
.
info
(
"Defaulting to use %s for distributed inference"
,
logger
.
info
(
"Defaulting to use %s for distributed inference"
,
backend
)
backend
)
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
==
1
:
self
.
distributed_executor_backend
=
"uni"
self
.
_verify_args
()
self
.
_verify_args
()
@
property
@
property
...
@@ -1454,6 +1437,17 @@ class SchedulerConfig:
...
@@ -1454,6 +1437,17 @@ class SchedulerConfig:
# Maximum length of a sequence (including prompt and generated text).
# Maximum length of a sequence (including prompt and generated text).
max_model_len
:
int
=
8192
max_model_len
:
int
=
8192
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills
:
int
=
1
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills
:
int
=
1
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold
:
int
=
0
# The number of slots to allocate per sequence per
# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
# decoding to store KV activations of tokens which may or may not be
...
@@ -1561,6 +1555,18 @@ class SchedulerConfig:
...
@@ -1561,6 +1555,18 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
)
self
.
max_num_batched_tokens
)
self
.
chunked_prefill_enabled
=
self
.
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
self
.
enable_chunked_prefill
if
self
.
max_num_partial_prefills
>
1
:
if
self
.
long_prefill_token_threshold
==
0
:
self
.
long_prefill_token_threshold
=
int
(
self
.
max_model_len
*
0.04
)
logger
.
info
(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d"
,
self
.
max_num_partial_prefills
,
self
.
max_long_partial_prefills
,
self
.
long_prefill_token_threshold
)
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
@@ -1592,6 +1598,29 @@ class SchedulerConfig:
...
@@ -1592,6 +1598,29 @@ class SchedulerConfig:
f
"(
{
self
.
num_scheduler_steps
}
) must be greater than or "
f
"(
{
self
.
num_scheduler_steps
}
) must be greater than or "
"equal to 1."
)
"equal to 1."
)
if
self
.
max_num_partial_prefills
<
1
:
raise
ValueError
(
f
"max_num_partial_prefills (
{
self
.
max_num_partial_prefills
}
) "
"must be greater than or equal to 1."
)
elif
self
.
max_num_partial_prefills
>
1
:
if
not
self
.
chunked_prefill_enabled
:
raise
ValueError
(
"Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1."
)
if
self
.
long_prefill_token_threshold
>
self
.
max_model_len
:
raise
ValueError
(
"long_prefill_token_threshold "
f
"(
{
self
.
long_prefill_token_threshold
}
) cannot be greater "
f
"than the max_model_len (
{
self
.
max_model_len
}
)."
)
if
(
self
.
max_long_partial_prefills
<
1
)
or
(
self
.
max_long_partial_prefills
>
self
.
max_num_partial_prefills
):
raise
ValueError
(
f
"max_long_partial_prefills (
{
self
.
max_long_partial_prefills
}
) "
"must be greater than or equal to 1 and less than or equal to "
f
"max_num_partial_prefills (
{
self
.
max_num_partial_prefills
}
)."
)
@
property
@
property
def
is_multi_step
(
self
)
->
bool
:
def
is_multi_step
(
self
)
->
bool
:
return
self
.
num_scheduler_steps
>
1
return
self
.
num_scheduler_steps
>
1
...
@@ -1666,6 +1695,18 @@ class SpeculativeConfig:
...
@@ -1666,6 +1695,18 @@ class SpeculativeConfig:
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
return
hash_str
@
staticmethod
def
hf_config_override
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
if
hf_config
.
model_type
==
"deepseek_v3"
:
hf_config
.
model_type
=
"deepseek_mtp"
if
hf_config
.
model_type
==
"deepseek_mtp"
:
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
({
"n_predict"
:
n_predict
,
"architectures"
:
[
"DeepSeekMTPModel"
]
})
return
hf_config
@
staticmethod
@
staticmethod
def
maybe_create_spec_config
(
def
maybe_create_spec_config
(
target_model_config
:
ModelConfig
,
target_model_config
:
ModelConfig
,
...
@@ -1754,9 +1795,16 @@ class SpeculativeConfig:
...
@@ -1754,9 +1795,16 @@ class SpeculativeConfig:
if
speculative_model
is
None
:
if
speculative_model
is
None
:
if
num_speculative_tokens
is
not
None
:
if
num_speculative_tokens
is
not
None
:
raise
ValueError
(
"num_speculative_tokens was provided without "
if
target_model_config
.
hf_text_config
.
model_type
\
"speculative_model."
)
==
"deepseek_v3"
:
return
None
# use the draft model from the same model:
speculative_model
=
target_model_config
.
model
else
:
raise
ValueError
(
"num_speculative_tokens was provided without "
"speculative_model."
)
else
:
return
None
if
(
speculative_disable_by_batch_size
is
not
None
if
(
speculative_disable_by_batch_size
is
not
None
and
speculative_disable_by_batch_size
<
2
):
and
speculative_disable_by_batch_size
<
2
):
...
@@ -1810,10 +1858,20 @@ class SpeculativeConfig:
...
@@ -1810,10 +1858,20 @@ class SpeculativeConfig:
max_seq_len_to_capture
=
target_model_config
.
max_seq_len_to_capture
=
target_model_config
.
max_seq_len_to_capture
,
max_seq_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
max_logprobs
=
target_model_config
.
max_logprobs
,
hf_overrides
=
SpeculativeConfig
.
hf_config_override
,
)
)
draft_hf_config
=
draft_model_config
.
hf_config
draft_hf_config
=
draft_model_config
.
hf_config
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if
"eagle-"
in
draft_model_config
.
model
.
lower
():
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
if
isinstance
(
draft_model_config
.
hf_config
,
EAGLEConfig
):
pass
else
:
eagle_config
=
EAGLEConfig
(
draft_model_config
.
hf_config
)
draft_model_config
.
hf_config
=
eagle_config
if
(
num_speculative_tokens
is
not
None
if
(
num_speculative_tokens
is
not
None
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
...
@@ -1935,8 +1993,9 @@ class SpeculativeConfig:
...
@@ -1935,8 +1993,9 @@ class SpeculativeConfig:
speculative_draft_tensor_parallel_size
=
1
speculative_draft_tensor_parallel_size
=
1
if
target_parallel_config
.
tensor_parallel_size
>
1
:
if
target_parallel_config
.
tensor_parallel_size
>
1
:
logger
.
warning
(
logger
.
warning
(
"MLPSpeculator cannot currently be run with tp>1; "
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1"
)
"setting speculative_draft_tensor_parallel_size=1"
,
draft_hf_config
.
model_type
)
else
:
else
:
speculative_draft_tensor_parallel_size
=
\
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
target_parallel_config
.
tensor_parallel_size
...
@@ -3089,15 +3148,6 @@ class VllmConfig:
...
@@ -3089,15 +3148,6 @@ class VllmConfig:
the final hidden states.
the final hidden states.
"""
"""
factors
:
List
[
Any
]
=
[]
factors
:
List
[
Any
]
=
[]
# summarize system state
from
torch._inductor.codecache
import
CacheBase
system_factors
=
CacheBase
.
get_system
()
factors
.
append
(
system_factors
)
# summarize pytorch state
from
torch._inductor.codecache
import
torch_key
torch_factors
=
torch_key
()
factors
.
append
(
torch_factors
)
# summarize vllm config
# summarize vllm config
vllm_factors
:
List
[
Any
]
=
[]
vllm_factors
:
List
[
Any
]
=
[]
...
...
vllm/model_executor/model_loader/utils.py
View file @
1fbdf957
...
@@ -80,8 +80,11 @@ def get_model_architecture(
...
@@ -80,8 +80,11 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
]
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTP'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
...
...
vllm/model_executor/models/deepseek_mtp.py
0 → 100644
View file @
1fbdf957
# SPDX-License-Identifier: Apache-2.0
import
os
import
re
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.deepseek_v2
import
(
DeepseekV2DecoderLayer
,
get_spec_layer_idx_from_weight_name
)
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
class
SharedHead
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
norm
(
hidden_states
)
class
DeepSeekMultiTokenPredictorLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
prefix
:
str
,
model_config
:
ModelConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
enorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
hnorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
,
bias
=
False
)
self
.
shared_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mtp_block
=
DeepseekV2DecoderLayer
(
config
,
prefix
,
model_config
,
cache_config
,
quant_config
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
previous_hidden_states
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
hidden_states
,
residual
=
self
.
mtp_block
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
None
)
hidden_states
=
residual
+
hidden_states
return
self
.
shared_head
(
hidden_states
)
class
DeepSeekMultiTokenPredictor
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
mtp_start_layer_idx
=
config
.
num_hidden_layers
self
.
num_mtp_layers
=
config
.
num_nextn_predict_layers
# to map the exact layer index from weights
self
.
layers
=
torch
.
nn
.
ModuleDict
({
str
(
idx
):
DeepSeekMultiTokenPredictorLayer
(
config
,
f
"
{
prefix
}
.layers.
{
idx
}
"
,
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
quant_config
=
vllm_config
.
quant_config
,
)
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
self
.
mtp_start_layer_idx
+
self
.
num_mtp_layers
)
})
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
previous_hidden_states
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
spec_step_idx
)](
input_ids
,
positions
,
kv_caches
[
spec_step_idx
],
attn_metadata
,
previous_hidden_states
,
inputs_embeds
,
spec_step_idx
,
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
mtp_layer
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
spec_step_idx
)]
logits
=
self
.
logits_processor
(
mtp_layer
.
shared_head
.
head
,
hidden_states
,
sampling_metadata
)
return
logits
class
DeepSeekMTP
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
model
=
DeepSeekMultiTokenPredictor
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
sampler
=
get_sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
previous_hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
previous_hidden_states
,
inputs_embeds
,
spec_step_idx
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
spec_step_idx
:
int
=
0
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
,
spec_step_idx
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
spec_layer
=
get_spec_layer_idx_from_weight_name
(
self
.
config
,
name
)
if
spec_layer
is
None
:
continue
name
=
self
.
_rewrite_spec_layer_name
(
spec_layer
,
name
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attn.eh_proj.weight"
,
"self_attn.q_proj.weight"
,
"self_attn.q_a_proj.weight"
,
"self_attn.q_b_proj.weight"
,
"self_attn.kv_a_proj_with_mqa.weight"
,
"self_attn.kv_b_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
"mlp.gate.weight"
,
"shared_experts.gate_up_proj.weight"
,
"shared_experts.down_proj.weight"
,
"shared_head.head.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
return
loaded_params
def
_rewrite_spec_layer_name
(
self
,
spec_layer
:
int
,
name
:
str
)
->
str
:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names
=
[
"embed_tokens"
,
"enorm"
,
"hnorm"
,
"eh_proj"
,
"shared_head"
]
spec_layer_weight
=
False
for
weight_name
in
spec_layer_weight_names
:
if
weight_name
in
name
:
spec_layer_weight
=
True
break
if
not
spec_layer_weight
:
# treat rest weights as weights for transformer layer block
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
f
"model.layers.
{
spec_layer
}
.mtp_block."
)
return
name
vllm/model_executor/models/deepseek_v2.py
View file @
1fbdf957
...
@@ -773,13 +773,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -773,13 +773,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
# TODO(simon): support nextn predict layers
spec_layer
=
get_spec_layer_idx_from_weight_name
(
self
.
config
,
name
)
if
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
if
spec_layer
is
not
None
:
)
and
self
.
config
.
num_nextn_predict_layers
>
0
:
continue
# skip spec decode layers for main model
assert
self
.
config
.
num_nextn_predict_layers
==
1
layer_idx
=
self
.
config
.
num_hidden_layers
if
name
.
startswith
(
f
"model.layers.
{
layer_idx
}
"
):
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
# Skip non-stacked layers and experts (experts handled below).
...
@@ -927,4 +923,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -927,4 +923,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
class
DeepseekV3ForCausalLM
(
DeepseekV2ForCausalLM
):
class
DeepseekV3ForCausalLM
(
DeepseekV2ForCausalLM
):
pass
pass
\ No newline at end of file
def
get_spec_layer_idx_from_weight_name
(
config
:
PretrainedConfig
,
weight_name
:
str
)
->
Optional
[
int
]:
if
hasattr
(
config
,
"num_nextn_predict_layers"
)
and
(
config
.
num_nextn_predict_layers
>
0
):
layer_idx
=
config
.
num_hidden_layers
for
i
in
range
(
config
.
num_nextn_predict_layers
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
return
layer_idx
+
i
return
None
vllm/model_executor/models/interfaces.py
View file @
1fbdf957
...
@@ -445,3 +445,30 @@ def supports_cross_encoding(
...
@@ -445,3 +445,30 @@ def supports_cross_encoding(
model
:
Union
[
Type
[
object
],
object
],
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
)
->
Union
[
TypeIs
[
Type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
@
runtime_checkable
class
SupportsTranscription
(
Protocol
):
"""The interface required for all models that support transcription."""
supports_transcription
:
ClassVar
[
Literal
[
True
]]
=
True
@
overload
def
supports_transcription
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
SupportsTranscription
]]:
...
@
overload
def
supports_transcription
(
model
:
object
)
->
TypeIs
[
SupportsTranscription
]:
...
def
supports_transcription
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsTranscription
]],
TypeIs
[
SupportsTranscription
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
SupportsTranscription
)
return
isinstance
(
model
,
SupportsTranscription
)
vllm/model_executor/models/registry.py
View file @
1fbdf957
...
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
...
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
is_hybrid
,
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
supports_pp
,
supports_transcription
)
from
.interfaces_base
import
is_text_generation_model
from
.interfaces_base
import
is_text_generation_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -182,6 +182,7 @@ _MULTIMODAL_MODELS = {
...
@@ -182,6 +182,7 @@ _MULTIMODAL_MODELS = {
_SPECULATIVE_DECODING_MODELS
=
{
_SPECULATIVE_DECODING_MODELS
=
{
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
}
}
...
@@ -212,6 +213,7 @@ class _ModelInfo:
...
@@ -212,6 +213,7 @@ class _ModelInfo:
has_inner_state
:
bool
has_inner_state
:
bool
is_attention_free
:
bool
is_attention_free
:
bool
is_hybrid
:
bool
is_hybrid
:
bool
supports_transcription
:
bool
@
staticmethod
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
...
@@ -225,7 +227,7 @@ class _ModelInfo:
...
@@ -225,7 +227,7 @@ class _ModelInfo:
has_inner_state
=
has_inner_state
(
model
),
has_inner_state
=
has_inner_state
(
model
),
is_attention_free
=
is_attention_free
(
model
),
is_attention_free
=
is_attention_free
(
model
),
is_hybrid
=
is_hybrid
(
model
),
is_hybrid
=
is_hybrid
(
model
),
)
supports_transcription
=
supports_transcription
(
model
)
)
class
_BaseRegisteredModel
(
ABC
):
class
_BaseRegisteredModel
(
ABC
):
...
@@ -472,6 +474,13 @@ class _ModelRegistry:
...
@@ -472,6 +474,13 @@ class _ModelRegistry:
)
->
bool
:
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_hybrid
return
model_cls
.
is_hybrid
def
is_transcription_model
(
self
,
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_transcription
ModelRegistry
=
_ModelRegistry
({
ModelRegistry
=
_ModelRegistry
({
...
...
vllm/sequence.py
View file @
1fbdf957
...
@@ -1384,6 +1384,8 @@ class ExecuteModelRequest(
...
@@ -1384,6 +1384,8 @@ class ExecuteModelRequest(
previous_logits
:
Optional
[
Logits
]
=
None
previous_logits
:
Optional
[
Logits
]
=
None
# The number of forward steps to run.
# The number of forward steps to run.
num_steps
:
int
=
1
num_steps
:
int
=
1
# The step index for spec model input.
spec_step_idx
:
Optional
[
int
]
=
None
# Finished request ids since last step.
# Finished request ids since last step.
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
...
...
vllm/spec_decode/draft_model_runner.py
View file @
1fbdf957
...
@@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
...
@@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
return
False
return
False
# TODO: Add support for other attn backends
# TODO: Add support for other attn backends
if
self
.
attn_backend
.
get_name
()
!=
"FLASH_ATTN"
:
if
self
.
attn_backend
.
get_name
()
not
in
(
"FLASH_ATTN"
,
"TRITON_MLA"
)
:
return
False
return
False
# TODO: Add support for LORA
# TODO: Add support for LORA
...
@@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
...
@@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
**
kwargs
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
"""Executes num_steps forward passes with advacement of input tensors
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
...
@@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
...
@@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
for
step
in
range
(
num_steps
):
for
step
in
range
(
num_steps
):
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
kwargs
=
{
"previous_hidden_states"
:
hidden_states
}
\
model_execute_
kwargs
=
{
"previous_hidden_states"
:
hidden_states
}
\
if
previous_hidden_states
is
not
None
else
{}
if
previous_hidden_states
is
not
None
else
{}
compute_logits_kwargs
=
{}
# Run model
# Run model
if
hasattr
(
self
.
model
.
config
,
"num_nextn_predict_layers"
):
# for DeepSeek MTP only to use the corresponding layer for
# each step
spec_step_idx
=
kwargs
.
get
(
"spec_step_idx"
,
step
)
model_execute_kwargs
[
"spec_step_idx"
]
=
spec_step_idx
compute_logits_kwargs
[
"spec_step_idx"
]
=
spec_step_idx
with
set_forward_context
(
model_input
.
attn_metadata
,
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
self
.
vllm_config
):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
...
@@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
...
@@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
**
kwargs
,
**
model_execute_
kwargs
,
)
)
# Compute the logits.
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
,
**
compute_logits_kwargs
)
if
not
self
.
is_driver_worker
:
return
[]
# Sample the next token.
# Sample the next token.
output
=
self
.
model
.
sample
(
output
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
1fbdf957
...
@@ -10,7 +10,8 @@ import torch
...
@@ -10,7 +10,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
,
get_tp_group
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -112,7 +113,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -112,7 +113,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
)
disable_log_stats
=
speculative_config
.
disable_log_stats
,
num_speculative_tokens
=
speculative_config
.
num_speculative_tokens
,
)
return
spec_decode_worker
return
spec_decode_worker
...
@@ -157,9 +160,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -157,9 +160,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
disable_log_stats
:
bool
,
num_speculative_tokens
:
int
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
allow_zero_draft_token_step
=
True
num_spec_prefill_steps
=
1
ngram_prompt_lookup_max
=
(
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
ngram_prompt_lookup_min
=
(
...
@@ -185,17 +190,21 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -185,17 +190,21 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
elif
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
elif
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
else
:
else
:
if
draft_tp
==
1
:
if
draft_tp
==
1
or
draft_model_config
.
hf_config
.
model_type
==
\
"deepseek_mtp"
:
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
draft_worker_kwargs
[
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
else
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"EAGLE does not support TP > 1 yet"
)
f
"
{
draft_model_config
.
hf_config
.
model_type
}
"
"does not support TP > 1 yet"
)
allow_zero_draft_token_step
=
False
allow_zero_draft_token_step
=
False
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
num_spec_prefill_steps
=
num_speculative_tokens
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
,
draft_tp
,
target_tp
)
proposer_worker
,
draft_tp
,
target_tp
)
...
@@ -247,7 +256,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -247,7 +256,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -260,6 +270,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -260,6 +270,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
num_spec_prefill_steps
:
int
=
1
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -290,6 +301,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -290,6 +301,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
draft model is larger than 1 (TODO: #5814)
num_spec_prefill_steps: number of speculative prefill steps to run
before the speculative decoding starts. This is only used when
the draft model is a deepseek_mtp model that requires prefill
kv cache separately for each MTP layer.
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
...
@@ -324,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -324,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
self
.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
_disable_log_stats
=
disable_log_stats
self
.
_num_spec_prefill_steps
=
num_spec_prefill_steps
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
...
@@ -340,8 +356,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -340,8 +356,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
proposer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
self
.
_metrics
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
self
.
spec_decode_sampler
.
init_tensors
(
self
.
rank
,
if
model_parallel_is_initialized
():
device_type
=
self
.
device
)
self
.
spec_decode_sampler
.
init_tensors
(
get_tp_group
().
local_rank
,
device_type
=
self
.
device
)
else
:
self
.
spec_decode_sampler
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
scorer_cls
:
Type
[
SpeculativeScorer
]
scorer_cls
:
Type
[
SpeculativeScorer
]
if
self
.
disable_mqa_scorer
:
if
self
.
disable_mqa_scorer
:
...
@@ -698,7 +718,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -698,7 +718,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
prepare_prefill_hidden_states
(
prepare_prefill_hidden_states
(
sampler_output
.
prefill_hidden_states
)
sampler_output
.
prefill_hidden_states
)
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
for
i
in
range
(
self
.
_num_spec_prefill_steps
):
execute_model_req
.
spec_step_idx
=
i
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
...
...
vllm/worker/model_runner.py
View file @
1fbdf957
...
@@ -100,6 +100,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -100,6 +100,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -1652,6 +1653,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1652,6 +1653,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
**
kwargs
,
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in ModelRunner"
)
raise
ValueError
(
"num_steps > 1 is not supported in ModelRunner"
)
...
@@ -1709,6 +1711,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1709,6 +1711,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"finished_requests_ids"
:
model_input
.
finished_requests_ids
,
"finished_requests_ids"
:
model_input
.
finished_requests_ids
,
"request_ids_to_seq_ids"
:
model_input
.
request_ids_to_seq_ids
,
"request_ids_to_seq_ids"
:
model_input
.
request_ids_to_seq_ids
,
}
if
self
.
has_inner_state
else
{}
}
if
self
.
has_inner_state
else
{}
previous_hidden_states
=
kwargs
.
get
(
"previous_hidden_states"
)
model_kwargs
=
{}
if
previous_hidden_states
is
not
None
:
model_kwargs
[
"previous_hidden_states"
]
=
previous_hidden_states
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
...
@@ -1726,7 +1732,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1726,7 +1732,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
**
seqlen_agnostic_kwargs
)
**
seqlen_agnostic_kwargs
,
**
model_kwargs
,
)
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
and
self
.
observability_config
.
collect_model_forward_time
):
...
@@ -1979,7 +1987,11 @@ class CUDAGraphRunner(nn.Module):
...
@@ -1979,7 +1987,11 @@ class CUDAGraphRunner(nn.Module):
# Copy the input tensors to the input buffers.
# Copy the input tensors to the input buffers.
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
if
positions
is
not
None
:
if
positions
is
not
None
:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
# in some case like MLA, it will reuse positions in metadata
# but truncate them to the original size
# so the shape is not padded, we need to copy partial only
self
.
input_buffers
[
"positions"
][:
positions
.
shape
[
0
]].
copy_
(
positions
,
non_blocking
=
True
)
if
self
.
backend_name
!=
"NO_ATTENTION"
:
if
self
.
backend_name
!=
"NO_ATTENTION"
:
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
...
...
vllm/worker/model_runner_base.py
View file @
1fbdf957
...
@@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
...
@@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
valid_attn_kwargs
=
{}
valid_attn_kwargs
=
{}
for
field
in
dataclasses
.
fields
(
attn_backend
.
get_metadata_cls
()):
for
field
in
dataclasses
.
fields
(
attn_backend
.
get_metadata_cls
()):
if
field
.
name
in
tensor_dict
:
if
field
.
name
in
tensor_dict
:
valid_attn_kwargs
[
field
.
name
]
=
tensor_dict
.
pop
(
field
.
name
)
if
field
.
name
==
"input_positions"
:
valid_attn_kwargs
[
field
.
name
]
=
tensor_dict
[
field
.
name
]
else
:
valid_attn_kwargs
[
field
.
name
]
=
tensor_dict
.
pop
(
field
.
name
)
attn_metadata
=
attn_backend
.
make_metadata
(
**
valid_attn_kwargs
)
attn_metadata
=
attn_backend
.
make_metadata
(
**
valid_attn_kwargs
)
tensor_dict
[
"attn_metadata"
]
=
attn_metadata
tensor_dict
[
"attn_metadata"
]
=
attn_metadata
...
...
vllm/worker/worker.py
View file @
1fbdf957
...
@@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
speculative_config
=
self
.
speculative_config
speculative_config
=
self
.
speculative_config
model_config
=
self
.
model_config
model_config
=
self
.
model_config
speculative_args
=
{}
if
speculative_config
is
None
\
speculative_args
=
{}
if
speculative_config
is
None
\
or
(
speculative_config
.
draft_model_config
.
model
==
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
==
model_config
.
model
)
\
model_config
.
hf_config
.
model_type
)
\
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
not
in
[
"medusa"
,
"mlp_speculator"
,
"eagle"
]
)
\
not
in
(
"medusa"
,
"mlp_speculator"
,
"eagle"
,
"deepseek_mtp"
)
)
\
else
{
"return_hidden_states"
:
True
}
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
...
...
vllm/worker/worker_base.py
View file @
1fbdf957
...
@@ -422,10 +422,12 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -422,10 +422,12 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return
None
return
None
model_input
,
worker_input
,
kwargs
=
inputs
model_input
,
worker_input
,
kwargs
=
inputs
num_steps
=
worker_input
.
num_steps
self
.
model_input
=
model_input
self
.
model_input
=
model_input
num_steps
=
worker_input
.
num_steps
if
(
execute_model_req
is
not
None
and
execute_model_req
.
spec_step_idx
):
kwargs
[
"spec_step_idx"
]
=
execute_model_req
.
spec_step_idx
self
.
execute_worker
(
worker_input
)
self
.
execute_worker
(
worker_input
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment