Unverified commit b411418f, authored by Woosuk Kwon and committed by GitHub
Browse files

[Chore] Remove Sampler from Model Code (#17084)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2bc0f72a
...@@ -35,7 +35,7 @@ from vllm.lora.request import LoRARequest ...@@ -35,7 +35,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
...@@ -1094,6 +1094,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1094,6 +1094,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Set after load_model. # Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.sampler = get_sampler()
set_cpu_offload_max_bytes( set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3)) int(self.cache_config.cpu_offload_gb * 1024**3))
...@@ -1832,7 +1833,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1832,7 +1833,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
...@@ -488,8 +488,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -488,8 +488,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
device="cpu", device="cpu",
pin_memory=True) pin_memory=True)
self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( self._base_model_runner.sampler.include_gpu_probs_tensor = True
True)
if frozen_model_input.sampling_metadata: if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True) True)
......
...@@ -18,7 +18,7 @@ from vllm.forward_context import set_forward_context ...@@ -18,7 +18,7 @@ from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap, MultiModalKwargs, MultiModalPlaceholderMap,
...@@ -410,6 +410,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -410,6 +410,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
self.sampler = get_sampler()
self.sampling_metadata_cache: SamplingMetadataCache = \ self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \ SamplingMetadataCache() \
...@@ -596,7 +597,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -596,7 +597,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment