Unverified commit b411418f, authored by Woosuk Kwon and committed by GitHub
Browse files

[Chore] Remove Sampler from Model Code (#17084)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2bc0f72a
......@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -497,7 +496,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -524,14 +522,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -24,7 +24,7 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
......@@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -1112,13 +1111,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
......@@ -1400,13 +1392,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
......
......@@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
......@@ -283,7 +282,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -311,14 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -494,7 +493,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -521,14 +519,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -8,7 +8,6 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
......@@ -21,7 +20,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -699,13 +697,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model(
self,
config: PretrainedConfig,
......@@ -908,7 +899,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
......@@ -946,13 +937,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = [
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -418,8 +417,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -440,14 +437,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......
......@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -310,7 +309,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -337,14 +335,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -317,7 +316,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -344,14 +342,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -396,8 +395,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -435,12 +432,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
......
......@@ -3,7 +3,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
......@@ -18,7 +17,6 @@ from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -438,13 +436,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
......@@ -628,13 +619,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
......
......@@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -669,7 +668,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
......@@ -724,14 +722,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
......
......@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -870,7 +869,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
# Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings.
......@@ -1004,23 +1002,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""Sample next tokens from computed logits.
Args:
logits: Computed logits for next token prediction
sampling_metadata: Metadata for sampling process
Returns:
Sampled tokens and related sampling information
"""
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -295,7 +295,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
output = self.model_runner.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
......
......@@ -50,11 +50,10 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
self.model_runner.sampler.include_gpu_probs_tensor = True
def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)
self.model_runner.sampler.should_modify_greedy_probs_inplace = True
@torch.inference_mode()
def sampler_output(
......
......@@ -410,9 +410,9 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
(self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor
) = True
(self.scorer_worker.model_runner.model.sampler.
(self.scorer_worker.model_runner.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace()
......
......@@ -38,6 +38,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
......@@ -153,6 +154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Sampler
self.sampler = Sampler()
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
......@@ -1096,7 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.model.sample(
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
......@@ -1106,7 +1110,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample(
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
......@@ -1383,8 +1387,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
# Compute prompt logprobs.
logprobs = self.model.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async.
......@@ -1502,8 +1506,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
bad_words_token_ids={},
)
try:
sampler_output = self.model.sample(
logits=logits, sampling_metadata=dummy_metadata)
sampler_output = self.sampler(logits=logits,
sampling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
......
......@@ -316,7 +316,7 @@ class CPUEncoderDecoderModelRunner(
return []
# Sample the next token.
output = self.model.sample(
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
......
......@@ -19,7 +19,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
......@@ -490,6 +490,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self.model: nn.Module # Set after init_Model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
......@@ -545,11 +546,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
return self.builder.build() # type: ignore
# sampler property will be used by spec_decode_worker
@property
def sampler(self):
return self.model.sampler
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
......@@ -677,7 +673,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
return []
# Sample the next token.
output = self.model.sample(
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
......
......@@ -205,7 +205,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.model.sample(
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
......
......@@ -41,7 +41,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
......@@ -316,6 +316,7 @@ class HpuModelAdapter:
def __init__(self, model, vllm_config):
self.model = model
self.sampler = get_sampler()
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.vllm_config = vllm_config
......@@ -454,7 +455,7 @@ class HpuModelAdapter:
return self.model.compute_logits(*args, **kwargs)
def sample(self, *args, **kwargs):
return self.model.sample(*args, **kwargs)
return self.sampler(*args, **kwargs)
class PreparePromptMetadata(NamedTuple):
......@@ -2167,7 +2168,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
self.model.model.sampler.include_gpu_probs_tensor = True
self.model.sampler.include_gpu_probs_tensor = True
cache_orig_output_tokens_len: List[Dict] = []
def try_revert_dummy_output_tokens():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment