Unverified Commit 1c3ffdbe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 sampling metadata (#25345)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent c438b295
...@@ -9,7 +9,6 @@ from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder, ...@@ -9,7 +9,6 @@ from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
LlavaMultiModalProcessor, LlavaMultiModalProcessor,
LlavaProcessingInfo) LlavaProcessingInfo)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -18,11 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -18,11 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
dummy_inputs=LlavaDummyInputsBuilder) dummy_inputs=LlavaDummyInputsBuilder)
class MyLlava(LlavaForConditionalGeneration): class MyLlava(LlavaForConditionalGeneration):
def compute_logits( def compute_logits(self,
self, hidden_states: torch.Tensor, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token # this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata) logits = super().compute_logits(hidden_states)
if logits is not None: if logits is not None:
logits.zero_() logits.zero_()
logits[:, 0] += 1.0 logits[:, 0] += 1.0
......
...@@ -6,16 +6,14 @@ from typing import Optional ...@@ -6,16 +6,14 @@ from typing import Optional
import torch import torch
from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM): class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits( def compute_logits(self,
self, hidden_states: torch.Tensor, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token # this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata) logits = super().compute_logits(hidden_states)
if logits is not None: if logits is not None:
logits.zero_() logits.zero_()
logits[:, 0] += 1.0 logits[:, 0] += 1.0
......
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter) PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"SamplingMetadata",
"set_random_seed", "set_random_seed",
"BasevLLMParameter", "BasevLLMParameter",
"PackedvLLMParameter", "PackedvLLMParameter",
......
...@@ -10,7 +10,6 @@ from vllm.distributed import (tensor_model_parallel_all_gather, ...@@ -10,7 +10,6 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -50,7 +49,6 @@ class LogitsProcessor(CustomOp): ...@@ -50,7 +49,6 @@ class LogitsProcessor(CustomOp):
self, self,
lm_head: VocabParallelEmbedding, lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
......
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -566,10 +565,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -566,10 +565,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -399,11 +399,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -399,11 +399,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return model_output return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self,
sampling_metadata) -> Optional[torch.Tensor]: hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# Compute final logits from hidden states (last pipeline rank only) # Compute final logits from hidden states (last pipeline rank only)
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -456,10 +455,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -456,10 +455,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -644,10 +643,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -644,10 +643,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -16,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( ...@@ -16,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
get_optimal_tiled_canvas) get_optimal_tiled_canvas)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...@@ -464,7 +463,5 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -464,7 +463,5 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader) default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
...@@ -421,10 +420,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -421,10 +420,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -623,10 +622,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -623,10 +622,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType from vllm.utils import LayerBlockType
...@@ -571,10 +570,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -571,10 +570,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -12,7 +12,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, ...@@ -12,7 +12,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -704,10 +703,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -704,10 +703,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP, SupportsQuant from .interfaces import SupportsPP, SupportsQuant
...@@ -355,10 +354,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -355,10 +354,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader) default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -1046,10 +1045,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1046,10 +1045,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
# Disallow image tokens which does not include special # Disallow image tokens which does not include special
# begin-image and end-image tokens # begin-image and end-image tokens
......
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
...@@ -437,10 +436,8 @@ class ChatGLMBaseModel(nn.Module): ...@@ -437,10 +436,8 @@ class ChatGLMBaseModel(nn.Module):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...@@ -478,7 +477,5 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -478,7 +477,5 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states)
sampling_metadata)
...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name, default_weight_loader, maybe_remap_kv_scale_name,
row_parallel_weight_loader) row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -448,15 +447,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -448,15 +447,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight') is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora: if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens, logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata) hidden_states)
else: else:
logits = self.logits_processor(self.model.embed_tokens.base_layer, logits = self.logits_processor(self.model.embed_tokens.base_layer,
hidden_states, sampling_metadata) hidden_states)
return logits return logits
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -462,10 +461,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -462,10 +461,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -488,10 +487,8 @@ class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -488,10 +487,8 @@ class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states)
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment