Unverified Commit 1c3ffdbe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 sampling metadata (#25345)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent c438b295
...@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
DeepseekV3ForCausalLM) DeepseekV3ForCausalLM)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
...@@ -222,10 +221,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -222,10 +221,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer, from .deepseek_v2 import (DeepseekV2DecoderLayer,
...@@ -124,15 +123,13 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -124,15 +123,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers) current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers[str(self.mtp_start_layer_idx + mtp_layer = self.layers[str(self.mtp_start_layer_idx +
current_step_idx)] current_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head, logits = self.logits_processor(mtp_layer.shared_head.head,
mtp_layer.shared_head(hidden_states), mtp_layer.shared_head(hidden_states))
sampling_metadata)
return logits return logits
...@@ -161,11 +158,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -161,11 +158,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata, return self.model.compute_logits(hidden_states, spec_step_idx)
spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
...@@ -914,10 +913,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, ...@@ -914,10 +913,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -15,7 +15,6 @@ from transformers import BatchFeature ...@@ -15,7 +15,6 @@ from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers import replace_linear_class from vllm.model_executor.models.transformers import replace_linear_class
...@@ -647,10 +646,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -647,10 +646,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -534,10 +533,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -534,10 +533,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -591,10 +590,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -591,10 +590,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -39,7 +39,6 @@ from vllm.config import VllmConfig ...@@ -39,7 +39,6 @@ from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -1292,11 +1291,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1292,11 +1291,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
"""compute logits""" """compute logits"""
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def _vision_forward( def _vision_forward(
self, self,
......
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .ernie45_moe import Ernie4_5_MoeMLP from .ernie45_moe import Ernie4_5_MoeMLP
...@@ -587,10 +586,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -587,10 +586,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -138,12 +137,10 @@ class ErnieMultiTokenPredictor(nn.Module): ...@@ -138,12 +137,10 @@ class ErnieMultiTokenPredictor(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
lm_head: ParallelLMHead, lm_head: ParallelLMHead,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(lm_head, hidden_states, logits = self.logits_processor(lm_head, hidden_states)
sampling_metadata)
return logits return logits
...@@ -180,11 +177,10 @@ class ErnieMTP(nn.Module, SupportsPP): ...@@ -180,11 +177,10 @@ class ErnieMTP(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, self.lm_head, return self.model.compute_logits(hidden_states, self.lm_head,
sampling_metadata, spec_step_idx) spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -534,10 +533,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -534,10 +533,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -517,10 +516,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -517,10 +516,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
...@@ -496,10 +495,8 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -496,10 +495,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
...@@ -675,10 +674,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -675,10 +674,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
......
...@@ -29,7 +29,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, ...@@ -29,7 +29,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -389,10 +388,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -389,10 +388,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.language_model.logits_processor( logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata) self.language_model.lm_head, hidden_states)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -412,10 +411,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -412,10 +411,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -409,10 +408,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -409,10 +408,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from ...attention.layers.encoder_only_attention import EncoderOnlyAttention from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
...@@ -542,10 +541,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -542,10 +541,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -14,7 +14,6 @@ from vllm.config import VllmConfig ...@@ -14,7 +14,6 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -704,10 +703,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -704,10 +703,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant from .interfaces import SupportsQuant
...@@ -814,10 +813,8 @@ class Gemma3nForCausalLM(nn.Module): ...@@ -814,10 +813,8 @@ class Gemma3nForCausalLM(nn.Module):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -685,10 +684,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -685,10 +684,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment