Unverified Commit 1c3ffdbe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 sampling metadata (#25345)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent c438b295
......@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -1208,10 +1207,8 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer,
Qwen3NextRMSNorm)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
......@@ -266,11 +265,9 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -45,7 +45,6 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
......@@ -1493,10 +1492,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -472,10 +471,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
......@@ -897,10 +896,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -495,10 +494,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
......@@ -332,10 +331,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
......@@ -339,10 +338,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
......@@ -405,10 +404,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
......@@ -1055,10 +1054,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
......@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
......@@ -638,10 +637,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalUUIDDict,
......@@ -798,10 +797,8 @@ class TransformersForCausalLM(TransformersBase):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
......
......@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
......@@ -616,10 +615,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -30,7 +30,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
# yapf: disable
from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalUUIDDict,
......@@ -454,10 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
......
......@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
......@@ -936,10 +935,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
return WhisperAudioInputs(input_features=input_features)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.proj_out, hidden_states,
sampling_metadata)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.proj_out, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
......@@ -1036,7 +1035,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
"""Compute logits for next token prediction.
......@@ -1047,8 +1045,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns:
Logits for next token prediction
"""
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class SamplingMetadata:
# Placeholder until it can be safely removed.
pass
......@@ -239,7 +239,7 @@ class EagleProposer:
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
......@@ -367,8 +367,7 @@ class EagleProposer:
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
......@@ -678,9 +677,7 @@ class EagleProposer:
# Get the output logits for the draft tokens.
logits = self.model.compute_logits(
draft_last_hidden_states.reshape(batch_size * level_num_drafts,
-1),
None,
)
-1))
# Sample a draft token for each child at the next tree level.
num_children = self.child_drafts_per_level[level + 1]
......
......@@ -41,7 +41,7 @@ class MedusaProposer:
) -> list[list[int]]:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks, None)
logits = self.model.compute_logits(blocks)
# Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
......
......@@ -2240,7 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return output
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
logits = self.model.compute_logits(sample_hidden_states)
else:
# Rare case.
assert not self.is_pooling_model
......@@ -2258,8 +2258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states,
None)
logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {}
if logits is not None:
......@@ -2706,7 +2705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
logits = self.model.compute_logits(prompt_hidden_states)
# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
......@@ -3105,7 +3104,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# To avoid breaking the sampler, we use a random tensor here instead.
hidden_states = torch.rand_like(hidden_states)
logits = self.model.compute_logits(hidden_states, None)
logits = self.model.compute_logits(hidden_states)
num_reqs = logits.size(0)
dummy_tensors = lambda v: torch.full(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment