Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c3ffdbe
Unverified
Commit
1c3ffdbe
authored
Sep 21, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 21, 2025
Browse files
[V0 Deprecation] Remove V0 sampling metadata (#25345)
Signed-off-by:
Woosuk Kwon
<
woosuk@thinkingmachines.ai
>
parent
c438b295
Changes
141
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
28 additions
and
87 deletions
+28
-87
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+1
-4
vllm/model_executor/models/qwen3_next_mtp.py
vllm/model_executor/models/qwen3_next_mtp.py
+1
-4
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+1
-4
vllm/model_executor/models/seed_oss.py
vllm/model_executor/models/seed_oss.py
+1
-4
vllm/model_executor/models/skyworkr1v.py
vllm/model_executor/models/skyworkr1v.py
+1
-4
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+2
-5
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+1
-4
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+1
-4
vllm/model_executor/models/step3_text.py
vllm/model_executor/models/step3_text.py
+2
-5
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+1
-4
vllm/model_executor/models/tarsier.py
vllm/model_executor/models/tarsier.py
+1
-4
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+1
-4
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+2
-5
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+1
-4
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+2
-5
vllm/model_executor/models/zamba2.py
vllm/model_executor/models/zamba2.py
+1
-4
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+0
-7
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+3
-6
vllm/v1/spec_decode/medusa.py
vllm/v1/spec_decode/medusa.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-5
No files found.
vllm/model_executor/models/qwen3_next.py
View file @
1c3ffdbe
...
...
@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.models.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -1208,10 +1207,8 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/qwen3_next_mtp.py
View file @
1c3ffdbe
...
...
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen3_next
import
(
Qwen3NextDecoderLayer
,
Qwen3NextRMSNorm
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
Qwen3NextConfig
...
...
@@ -266,11 +265,9 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
spec_step_idx
:
int
=
0
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
1c3ffdbe
...
...
@@ -45,7 +45,6 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
_ACTIVATION_REGISTRY
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
...
...
@@ -1493,10 +1492,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/seed_oss.py
View file @
1c3ffdbe
...
...
@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
...
...
@@ -472,10 +471,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/skyworkr1v.py
View file @
1c3ffdbe
...
...
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.models.intern_vit
import
(
InternVisionModel
,
InternVisionPatchModel
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
...
...
@@ -897,10 +896,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/solar.py
View file @
1c3ffdbe
...
...
@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
...
...
@@ -495,10 +494,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/stablelm.py
View file @
1c3ffdbe
...
...
@@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
...
...
@@ -332,10 +331,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/starcoder2.py
View file @
1c3ffdbe
...
...
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
...
...
@@ -339,10 +338,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/step3_text.py
View file @
1c3ffdbe
...
...
@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
...
...
@@ -405,10 +404,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/step3_vl.py
View file @
1c3ffdbe
...
...
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
NestedTensors
)
...
...
@@ -1055,10 +1054,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/tarsier.py
View file @
1c3ffdbe
...
...
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.llava
import
LlavaDummyInputsBuilder
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
BaseMultiModalProcessorCache
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargsItems
...
...
@@ -638,10 +637,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/transformers.py
View file @
1c3ffdbe
...
...
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargsItems
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalUUIDDict
,
...
...
@@ -798,10 +797,8 @@ class TransformersForCausalLM(TransformersBase):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
...
...
vllm/model_executor/models/ultravox.py
View file @
1c3ffdbe
...
...
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.model_loader
import
DefaultModelLoader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
NestedTensors
)
...
...
@@ -616,10 +615,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/voxtral.py
View file @
1c3ffdbe
...
...
@@ -30,7 +30,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
# yapf: disable
from
vllm.model_executor.models.whisper
import
WhisperEncoder
# yapf: enable
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
MultiModalUUIDDict
,
...
...
@@ -454,10 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
@
classmethod
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
...
...
vllm/model_executor/models/whisper.py
View file @
1c3ffdbe
...
...
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
)
...
...
@@ -936,10 +935,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
return
WhisperAudioInputs
(
input_features
=
input_features
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
proj_out
,
hidden_states
,
sampling_metadata
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
proj_out
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/zamba2.py
View file @
1c3ffdbe
...
...
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
...
...
@@ -1036,7 +1035,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
"""Compute logits for next token prediction.
...
...
@@ -1047,8 +1045,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns:
Logits for next token prediction
"""
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/sampling_metadata.py
deleted
100644 → 0
View file @
c438b295
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class
SamplingMetadata
:
# Placeholder until it can be safely removed.
pass
vllm/v1/spec_decode/eagle.py
View file @
1c3ffdbe
...
...
@@ -239,7 +239,7 @@ class EagleProposer:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
...
...
@@ -367,8 +367,7 @@ class EagleProposer:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
])
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
...
...
@@ -678,9 +677,7 @@ class EagleProposer:
# Get the output logits for the draft tokens.
logits
=
self
.
model
.
compute_logits
(
draft_last_hidden_states
.
reshape
(
batch_size
*
level_num_drafts
,
-
1
),
None
,
)
-
1
))
# Sample a draft token for each child at the next tree level.
num_children
=
self
.
child_drafts_per_level
[
level
+
1
]
...
...
vllm/v1/spec_decode/medusa.py
View file @
1c3ffdbe
...
...
@@ -41,7 +41,7 @@ class MedusaProposer:
)
->
list
[
list
[
int
]]:
# Generate blocks and compute logits
blocks
=
self
.
model
(
target_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
blocks
,
None
)
logits
=
self
.
model
.
compute_logits
(
blocks
)
# Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1c3ffdbe
...
...
@@ -2240,7 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
output
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
else
:
# Rare case.
assert
not
self
.
is_pooling_model
...
...
@@ -2258,8 +2258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
None
else
:
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
model_output_broadcast_data
=
{}
if
logits
is
not
None
:
...
...
@@ -2706,7 +2705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
offset
=
self
.
query_start_loc
.
np
[
req_idx
].
item
()
prompt_hidden_states
=
hidden_states
[
offset
:
offset
+
num_logits
]
logits
=
self
.
model
.
compute_logits
(
prompt_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
prompt_hidden_states
)
# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
...
...
@@ -3105,7 +3104,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# To avoid breaking the sampler, we use a random tensor here instead.
hidden_states
=
torch
.
rand_like
(
hidden_states
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
num_reqs
=
logits
.
size
(
0
)
dummy_tensors
=
lambda
v
:
torch
.
full
(
...
...
Prev
1
…
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment