Unverified Commit b411418f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Chore] Remove Sampler from Model Code (#17084)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2bc0f72a
......@@ -3,7 +3,6 @@
import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
......@@ -14,7 +13,6 @@ from transformers import BartTokenizer, BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
......@@ -673,7 +671,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
self.logits_processor = LogitsProcessor(self.vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
......@@ -716,11 +713,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -929,12 +921,6 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise NotImplementedError(
'Florence2 only supports COSINE as temporal embedding.')
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
......@@ -1110,13 +1096,6 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -27,7 +27,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -270,10 +269,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.patch_size
......@@ -387,14 +382,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.language_model.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -388,7 +387,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model = GemmaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -415,14 +413,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -388,7 +387,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -415,14 +413,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -493,7 +492,6 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -521,14 +519,6 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
......@@ -12,7 +12,6 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -503,10 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def dtype(self):
return next(self.parameters()).dtype
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -607,7 +602,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
**kwargs: object) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
......@@ -704,13 +699,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
......@@ -267,7 +266,6 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -295,14 +293,6 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -255,7 +254,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.lm_head = self.lm_head.tie_weights(self.transformer.wte)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -282,14 +280,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -302,7 +301,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -329,14 +327,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -306,7 +305,6 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -333,14 +331,6 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -299,7 +298,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
......@@ -326,14 +324,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -441,8 +440,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -464,14 +461,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -391,8 +390,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 /
self.config.logits_scaling)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -428,14 +425,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -20,7 +20,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -295,8 +294,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 /
self.config.logits_scaling)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -332,14 +329,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -39,7 +39,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -521,7 +520,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config.vocab_size,
self.output_multiplier_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -551,14 +549,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = ["rotary_emb.inv_freq"]
......
......@@ -28,7 +28,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -603,7 +602,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
......@@ -754,14 +752,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -13,7 +13,6 @@ from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -103,14 +102,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
"""Return `None` if TP rank > 0."""
...
def sample(
self,
logits: T,
sampling_metadata: "SamplingMetadata",
) -> "SamplerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_text_generation_model(
......
......@@ -23,7 +23,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -336,7 +335,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -363,14 +361,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -423,7 +413,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
prefix=prefix,
model_type=model_type)
for attr in ("output", "logits_processor", "sampler"):
for attr in ("output", "logits_processor"):
delattr(self, attr)
config = vllm_config.model_config.hf_config
......
......@@ -8,7 +8,6 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
......@@ -20,7 +19,6 @@ from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -698,13 +696,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model(
self,
config: PretrainedConfig,
......@@ -903,7 +894,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
......@@ -941,13 +932,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
......
......@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -308,7 +307,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
config.mup_width_scale)
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -335,14 +333,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment