Unverified Commit b411418f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Chore] Remove Sampler from Model Code (#17084)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2bc0f72a
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -497,7 +496,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -497,7 +496,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -524,14 +522,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -524,14 +522,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union) Union)
...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -1112,13 +1111,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1112,13 +1111,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models. # seems to avoid vision encoder sections for some models.
...@@ -1400,13 +1392,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1400,13 +1392,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -283,7 +282,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -283,7 +282,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -311,14 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -311,14 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -494,7 +493,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): ...@@ -494,7 +493,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -521,14 +519,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): ...@@ -521,14 +519,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# -------------------------------------------------------- # --------------------------------------------------------
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch import torch
...@@ -21,7 +20,6 @@ from vllm.config import VllmConfig ...@@ -21,7 +20,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel, from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -699,13 +697,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -699,13 +697,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
(llm_quant_config is not None): (llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model") quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model( def _init_vision_model(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -908,7 +899,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -908,7 +899,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]: ) -> IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None input_ids = None
...@@ -946,13 +937,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -946,13 +937,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
skip_prefixes = [ skip_prefixes = [
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -41,7 +41,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -418,8 +417,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -418,8 +417,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -440,14 +437,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -440,14 +437,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -310,7 +309,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -310,7 +309,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -337,14 +335,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -337,14 +335,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -317,7 +316,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -317,7 +316,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -344,14 +342,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -344,14 +342,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -396,8 +395,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -396,8 +395,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -435,12 +432,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -435,12 +432,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
...@@ -18,7 +17,6 @@ from vllm.config import VllmConfig ...@@ -18,7 +17,6 @@ from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -438,13 +436,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -438,13 +436,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
Get the module prefix in multimodal models Get the module prefix in multimodal models
...@@ -628,13 +619,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -628,13 +619,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
......
...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -669,7 +668,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -669,7 +668,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -724,14 +722,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -724,14 +722,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
......
...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( ...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -870,7 +869,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -870,7 +869,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
# Initialize logits processing and sampling # Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings. """Convert input token IDs to embeddings.
...@@ -1004,23 +1002,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -1004,23 +1002,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""Sample next tokens from computed logits.
Args:
logits: Computed logits for next token prediction
sampling_metadata: Metadata for sampling process
Returns:
Sampled tokens and related sampling information
"""
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -295,7 +295,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): ...@@ -295,7 +295,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model_runner.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
...@@ -50,11 +50,10 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): ...@@ -50,11 +50,10 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
def set_include_gpu_probs_tensor(self) -> None: def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker # Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True self.model_runner.sampler.include_gpu_probs_tensor = True
def set_should_modify_greedy_probs_inplace(self) -> None: def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( self.model_runner.sampler.should_modify_greedy_probs_inplace = True
True)
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
......
...@@ -410,9 +410,9 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -410,9 +410,9 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
NOTE(cade): This will require a special check if the proposer worker NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation). does not have a sampler (e.g. ngram speculation).
""" """
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor (self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor
) = True ) = True
(self.scorer_worker.model_runner.model.sampler. (self.scorer_worker.model_runner.sampler.
should_modify_greedy_probs_inplace) = True should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace() self.proposer_worker.set_should_modify_greedy_probs_inplace()
......
...@@ -38,6 +38,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ...@@ -38,6 +38,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
...@@ -153,6 +154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -153,6 +154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
# Sampler
self.sampler = Sampler()
# Lazy initialization # Lazy initialization
# self.model: nn.Module # Set after load_model # self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
...@@ -1096,7 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1096,7 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed. # Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None: if spec_decode_metadata is None:
sampler_output = self.model.sample( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
...@@ -1106,7 +1110,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1106,7 +1110,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits # logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor. # won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample( sampler_output = self.sampler(
logits=bonus_logits, logits=bonus_logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
...@@ -1383,8 +1387,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1383,8 +1387,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
# Compute prompt logprobs. # Compute prompt logprobs.
logprobs = self.model.sampler.compute_logprobs(logits) logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( token_ids, logprobs, ranks = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids) logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async. # Transfer GPU->CPU async.
...@@ -1502,8 +1506,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1502,8 +1506,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
bad_words_token_ids={}, bad_words_token_ids={},
) )
try: try:
sampler_output = self.model.sample( sampler_output = self.sampler(logits=logits,
logits=logits, sampling_metadata=dummy_metadata) sampling_metadata=dummy_metadata)
except RuntimeError as e: except RuntimeError as e:
if 'out of memory' in str(e): if 'out of memory' in str(e):
raise RuntimeError( raise RuntimeError(
......
...@@ -316,7 +316,7 @@ class CPUEncoderDecoderModelRunner( ...@@ -316,7 +316,7 @@ class CPUEncoderDecoderModelRunner(
return [] return []
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
...@@ -19,7 +19,7 @@ from vllm.lora.request import LoRARequest ...@@ -19,7 +19,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
...@@ -490,6 +490,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -490,6 +490,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
# Set after load_model. # Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
if hasattr(self, "_builder_cls"): if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls` # multi-step model runner does not have `_builder_cls`
...@@ -545,11 +546,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -545,11 +546,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
return self.builder.build() # type: ignore return self.builder.build() # type: ignore
# sampler property will be used by spec_decode_worker
@property
def sampler(self):
return self.model.sampler
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
...@@ -677,7 +673,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): ...@@ -677,7 +673,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
return [] return []
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
...@@ -205,7 +205,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -205,7 +205,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
...@@ -41,7 +41,7 @@ from vllm.lora.request import LoRARequest ...@@ -41,7 +41,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -316,6 +316,7 @@ class HpuModelAdapter: ...@@ -316,6 +316,7 @@ class HpuModelAdapter:
def __init__(self, model, vllm_config): def __init__(self, model, vllm_config):
self.model = model self.model = model
self.sampler = get_sampler()
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true'] '0').lower() in ['1', 'true']
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -454,7 +455,7 @@ class HpuModelAdapter: ...@@ -454,7 +455,7 @@ class HpuModelAdapter:
return self.model.compute_logits(*args, **kwargs) return self.model.compute_logits(*args, **kwargs)
def sample(self, *args, **kwargs): def sample(self, *args, **kwargs):
return self.model.sample(*args, **kwargs) return self.sampler(*args, **kwargs)
class PreparePromptMetadata(NamedTuple): class PreparePromptMetadata(NamedTuple):
...@@ -2167,7 +2168,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ...@@ -2167,7 +2168,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
# in case of multi-step scheduling # in case of multi-step scheduling
# we only want to pythonize in the last step # we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True sampling_metadata.skip_sampler_cpu_output = True
self.model.model.sampler.include_gpu_probs_tensor = True self.model.sampler.include_gpu_probs_tensor = True
cache_orig_output_tokens_len: List[Dict] = [] cache_orig_output_tokens_len: List[Dict] = []
def try_revert_dummy_output_tokens(): def try_revert_dummy_output_tokens():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment