Unverified Commit 1c3ffdbe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 sampling metadata (#25345)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent c438b295
...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -289,10 +288,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -289,10 +288,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -52,7 +52,6 @@ from vllm.distributed import (get_tensor_model_parallel_world_size, ...@@ -52,7 +52,6 @@ from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state) parallel_state)
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -1654,10 +1653,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1654,10 +1653,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -703,10 +702,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -703,10 +702,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name
...@@ -155,15 +154,13 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -155,15 +154,13 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers) current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers[str(self.mtp_start_layer_idx + mtp_layer = self.layers[str(self.mtp_start_layer_idx +
current_step_idx)] current_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head, logits = self.logits_processor(mtp_layer.shared_head.head,
mtp_layer.shared_head(hidden_states), mtp_layer.shared_head(hidden_states))
sampling_metadata)
return logits return logits
...@@ -192,11 +189,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP): ...@@ -192,11 +189,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata, return self.model.compute_logits(hidden_states, spec_step_idx)
spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from ..layers.pooler import DispatchPooler, Pooler from ..layers.pooler import DispatchPooler, Pooler
...@@ -307,10 +306,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -307,10 +306,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -329,10 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -329,10 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -329,10 +328,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -329,10 +328,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias) self.lm_head.bias)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -321,10 +320,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -321,10 +320,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states, logits = self.logits_processor(self.embed_out, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -670,10 +669,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP): ...@@ -670,10 +669,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
return self.model(input_ids, positions, intermediate_tensors, return self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -463,11 +462,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -463,11 +462,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds) inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(self,
self, hidden_states: torch.Tensor, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
......
...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -776,12 +775,8 @@ class GraniteSpeechForConditionalGeneration( ...@@ -776,12 +775,8 @@ class GraniteSpeechForConditionalGeneration(
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits( return self.language_model.compute_logits(hidden_states)
hidden_states,
sampling_metadata,
)
def load_weights( def load_weights(
self, self,
......
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -511,11 +510,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -511,11 +510,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(self,
self, hidden_states: torch.Tensor, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
......
...@@ -32,7 +32,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -32,7 +32,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType from vllm.utils import LayerBlockType
...@@ -672,10 +671,8 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -672,10 +671,8 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
...@@ -311,11 +310,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -311,11 +310,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(self,
self, hidden_states: torch.Tensor, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
......
...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -528,10 +527,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -528,10 +527,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -54,7 +54,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -54,7 +54,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
...@@ -1004,10 +1003,8 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -1004,10 +1003,8 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
......
...@@ -31,7 +31,6 @@ from transformers.modeling_utils import no_init_weights ...@@ -31,7 +31,6 @@ from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -962,10 +961,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -962,10 +961,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights( def load_weights(
self, self,
......
...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -738,10 +737,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -738,10 +737,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -13,11 +13,9 @@ from vllm.utils import supports_kw ...@@ -13,11 +13,9 @@ from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.sampling_metadata import SamplingMetadata
else: else:
VllmConfig = Any VllmConfig = Any
Pooler = Any Pooler = Any
SamplingMetadata = Any
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -100,7 +98,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): ...@@ -100,7 +98,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
def compute_logits( def compute_logits(
self, self,
hidden_states: T, hidden_states: T,
sampling_metadata: SamplingMetadata,
) -> Optional[T]: ) -> Optional[T]:
"""Return `None` if TP rank > 0.""" """Return `None` if TP rank > 0."""
... ...
......
...@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -358,10 +357,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -358,10 +357,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states, logits = self.logits_processor(self.output, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment