Unverified Commit 1c3ffdbe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 sampling metadata (#25345)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent c438b295
...@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader) default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -1208,10 +1207,8 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -1208,10 +1207,8 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states, return self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer,
Qwen3NextRMSNorm) Qwen3NextRMSNorm)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.transformers_utils.configs import Qwen3NextConfig
...@@ -266,11 +265,9 @@ class Qwen3NextMTP(nn.Module, SupportsPP): ...@@ -266,11 +265,9 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states, return self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -45,7 +45,6 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -45,7 +45,6 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -1493,10 +1492,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1493,10 +1492,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -472,10 +471,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -472,10 +471,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.intern_vit import (InternVisionModel, from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -897,10 +896,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -897,10 +896,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -495,10 +494,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -495,10 +494,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds) inputs_embeds)
return model_output return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -332,10 +331,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -332,10 +331,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -339,10 +338,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -339,10 +338,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -405,10 +404,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -405,10 +404,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems, NestedTensors)
...@@ -1055,10 +1054,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1055,10 +1054,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
...@@ -638,10 +637,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -638,10 +637,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalUUIDDict, MultiModalInputs, MultiModalUUIDDict,
...@@ -798,10 +797,8 @@ class TransformersForCausalLM(TransformersBase): ...@@ -798,10 +797,8 @@ class TransformersForCausalLM(TransformersBase):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
......
...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn ...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems, NestedTensors)
...@@ -616,10 +615,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -616,10 +615,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: return self.language_model.compute_logits(hidden_states)
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -30,7 +30,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -30,7 +30,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
# yapf: disable # yapf: disable
from vllm.model_executor.models.whisper import WhisperEncoder from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable # yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalUUIDDict, MultiModalKwargsItems, MultiModalUUIDDict,
...@@ -454,10 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -454,10 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
@classmethod @classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig, def get_speech_to_text_config(cls, model_config: ModelConfig,
......
...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -936,10 +935,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -936,10 +935,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
return WhisperAudioInputs(input_features=input_features) return WhisperAudioInputs(input_features=input_features)
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.proj_out, hidden_states)
logits = self.logits_processor(self.proj_out, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid
...@@ -1036,7 +1035,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -1036,7 +1035,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
"""Compute logits for next token prediction. """Compute logits for next token prediction.
...@@ -1047,8 +1045,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -1047,8 +1045,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns: Returns:
Logits for next token prediction Logits for next token prediction
""" """
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class SamplingMetadata:
# Placeholder until it can be safely removed.
pass
...@@ -239,7 +239,7 @@ class EagleProposer: ...@@ -239,7 +239,7 @@ class EagleProposer:
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
...@@ -367,8 +367,7 @@ class EagleProposer: ...@@ -367,8 +367,7 @@ class EagleProposer:
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size])
None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
...@@ -678,9 +677,7 @@ class EagleProposer: ...@@ -678,9 +677,7 @@ class EagleProposer:
# Get the output logits for the draft tokens. # Get the output logits for the draft tokens.
logits = self.model.compute_logits( logits = self.model.compute_logits(
draft_last_hidden_states.reshape(batch_size * level_num_drafts, draft_last_hidden_states.reshape(batch_size * level_num_drafts,
-1), -1))
None,
)
# Sample a draft token for each child at the next tree level. # Sample a draft token for each child at the next tree level.
num_children = self.child_drafts_per_level[level + 1] num_children = self.child_drafts_per_level[level + 1]
......
...@@ -41,7 +41,7 @@ class MedusaProposer: ...@@ -41,7 +41,7 @@ class MedusaProposer:
) -> list[list[int]]: ) -> list[list[int]]:
# Generate blocks and compute logits # Generate blocks and compute logits
blocks = self.model(target_hidden_states) blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks, None) logits = self.model.compute_logits(blocks)
# Get draft tokens and transpose the result # Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
......
...@@ -2240,7 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2240,7 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return output return output
sample_hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states)
else: else:
# Rare case. # Rare case.
assert not self.is_pooling_model assert not self.is_pooling_model
...@@ -2258,8 +2258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2258,8 +2258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = None logits = None
else: else:
sample_hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, logits = self.model.compute_logits(sample_hidden_states)
None)
model_output_broadcast_data = {} model_output_broadcast_data = {}
if logits is not None: if logits is not None:
...@@ -2706,7 +2705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2706,7 +2705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc.np[req_idx].item() offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits] prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None) logits = self.model.compute_logits(prompt_hidden_states)
# Get the "target" tokens for each index. For prompt at index i, # Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want # the token at prompt index i+1 is the "sampled" token we want
...@@ -3105,7 +3104,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3105,7 +3104,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# To avoid breaking the sampler, we use a random tensor here instead. # To avoid breaking the sampler, we use a random tensor here instead.
hidden_states = torch.rand_like(hidden_states) hidden_states = torch.rand_like(hidden_states)
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states)
num_reqs = logits.size(0) num_reqs = logits.size(0)
dummy_tensors = lambda v: torch.full( dummy_tensors = lambda v: torch.full(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment