Unverified Commit b411418f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Chore] Remove Sampler from Model Code (#17084)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2bc0f72a
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -409,7 +408,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -409,7 +408,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -466,14 +464,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -466,14 +464,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -58,7 +58,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -58,7 +58,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -298,7 +297,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -298,7 +297,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = get_sampler()
self.media_placeholder: int = self.config.media_placeholder_token_id self.media_placeholder: int = self.config.media_placeholder_token_id
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_world_size = get_tensor_model_parallel_world_size() self.tp_world_size = get_tensor_model_parallel_world_size()
...@@ -409,7 +407,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -409,7 +407,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner from # NOTE: In v1, inputs_embeds is always generated at model runner from
...@@ -447,14 +445,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -447,14 +445,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata, **kwargs) sampling_metadata, **kwargs)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
config = self.config.text_config config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
......
...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -515,8 +514,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -515,8 +514,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -551,11 +548,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -551,11 +548,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union, cast) TypeVar, Union, cast)
...@@ -23,7 +22,6 @@ from vllm.model_executor.layers.activation import get_act_fn ...@@ -23,7 +22,6 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -546,13 +544,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -546,13 +544,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -763,13 +754,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -763,13 +754,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union) Protocol, Set, Tuple, TypedDict, TypeVar, Union)
...@@ -13,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -13,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
...@@ -250,13 +248,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -250,13 +248,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -585,13 +576,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -585,13 +576,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
...@@ -12,7 +11,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig, ...@@ -12,7 +11,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig,
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -301,13 +299,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -301,13 +299,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, List[torch.Tensor]]:
...@@ -469,13 +460,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -469,13 +460,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -16,7 +15,6 @@ from typing_extensions import NotRequired ...@@ -16,7 +15,6 @@ from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -455,13 +453,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -455,13 +453,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -957,13 +948,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -957,13 +948,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -213,7 +212,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -213,7 +212,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors) self.backbone.make_empty_intermediate_tensors)
...@@ -267,14 +265,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -267,14 +265,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -208,7 +207,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -208,7 +207,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors) self.backbone.make_empty_intermediate_tensors)
...@@ -282,14 +280,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -282,14 +280,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -553,7 +552,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -553,7 +552,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(unpadded_vocab_size, self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -584,14 +582,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -584,14 +582,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
import math import math
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union) Union)
...@@ -40,7 +40,6 @@ from vllm.config import VllmConfig ...@@ -40,7 +40,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed) get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
...@@ -758,13 +757,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -758,13 +757,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors) self.llm.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.llm, "sampler"):
return self.llm.sampler
return get_sampler()
def _parse_and_validate_vision_input( def _parse_and_validate_vision_input(
self, self,
modality: str, modality: str,
...@@ -946,14 +938,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -946,14 +938,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.llm.compute_logits(hidden_states, sampling_metadata) return self.llm.compute_logits(hidden_states, sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -994,7 +993,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -994,7 +993,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size) self.config.vocab_size)
self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
...@@ -1030,16 +1028,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1030,16 +1028,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors: device: torch.device) -> IntermediateTensors:
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union) TypeVar, Union)
...@@ -19,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -19,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -435,13 +433,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -435,13 +433,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -598,13 +589,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -598,13 +589,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -40,7 +40,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -454,7 +453,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -454,7 +453,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -481,14 +479,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -481,14 +479,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
......
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -372,7 +371,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -372,7 +371,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -399,14 +397,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -399,14 +397,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -1211,7 +1210,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1211,7 +1210,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
self.logits_processor = LogitsProcessor(config.output_hidden_states, self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size) config.text_config.vocab_size)
self.sampler = get_sampler()
def compute_logits( def compute_logits(
self, self,
...@@ -1222,14 +1220,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1222,14 +1220,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states, sampling_metadata) hidden_states, sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def unpack_data(self, def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor], image_data: Union[List[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor: padding_value=0) -> torch.Tensor:
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
# limitations under the License. # limitations under the License.
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from functools import cached_property
from itertools import tee from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
...@@ -38,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -38,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -682,13 +680,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -682,13 +680,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
# num_images, 1, num_chunks, channel, image_size, image_size # num_images, 1, num_chunks, channel, image_size, image_size
...@@ -785,10 +776,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -785,10 +776,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def separate_weights( def separate_weights(
self, self,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[Tuple[str, torch.Tensor]],
......
...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -1394,7 +1393,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1394,7 +1393,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.logits_processor = LogitsProcessor(config.embedding_size self.logits_processor = LogitsProcessor(config.embedding_size
or config.vocab_size) or config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -1506,7 +1504,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1506,7 +1504,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> torch.Tensor:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
...@@ -1532,14 +1530,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1532,14 +1530,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -298,7 +297,6 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -298,7 +297,6 @@ class MPTForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "transformer")) prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
...@@ -325,14 +323,6 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -325,14 +323,6 @@ class MPTForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -416,8 +415,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -416,8 +415,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -444,14 +441,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -444,14 +441,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment