Unverified Commit 26e673fe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent 65a5910c
...@@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory, ...@@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
initialize_model_parallel) initialize_model_parallel)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import set_default_torch_num_threads from vllm.utils import set_default_torch_num_threads
......
...@@ -7,8 +7,8 @@ from typing import Optional ...@@ -7,8 +7,8 @@ from typing import Optional
import pytest import pytest
from transformers import AutoModelForSpeechSeq2Seq from transformers import AutoModelForSpeechSeq2Seq
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs
from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput,
VllmRunner) VllmRunner)
......
...@@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download ...@@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.multimodal.image import convert_image_mode, rescale_image_size
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
PromptImageInput, VllmRunner) PromptImageInput, VllmRunner)
......
...@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk ...@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.logprobs import Logprob, SampleLogprobs
from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs
from ....utils import VLLM_PATH, large_gpu_test from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
......
...@@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, ...@@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
GenerationConfig, GenerationMixin) GenerationConfig, GenerationMixin)
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.sequence import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .....conftest import HfRunner, ImageAsset, ImageTestAssets
......
...@@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM ...@@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import RunnerOption from vllm.config import RunnerOption
from vllm.sequence import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
......
...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig ...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from .registry import HF_EXAMPLE_MODELS from .registry import HF_EXAMPLE_MODELS
......
...@@ -8,10 +8,7 @@ import pytest ...@@ -8,10 +8,7 @@ import pytest
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
...@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast): ...@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast):
assert decoded_text == '' assert decoded_text == ''
assert out_ids == [len(tokenizer)] assert out_ids == [len(tokenizer)]
@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
tokenizer = get_tokenizer(
tokenizer_name,
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
trust_remote_code=False,
revision=None,
)
return Detokenizer(tokenizer)
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]:
return tokenizer(complete_sequence, add_special_tokens=False).input_ids
def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or []
return Sequence(
seq_id=0,
inputs=token_inputs(prompt_token_ids),
block_size=16,
)
def create_dummy_logprobs(
complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
} for token_id in complete_sequence_token_ids]
def create_dummy_prompt_logprobs(
complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
# logprob for the first prompt token is None.
logprobs: list[Optional[dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
logprobs=2)
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token: list[str] = []
sequential_logprobs_text_other_token: list[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
detokenizer.decode_sequence_inplace(seq, sampling_params)
sequential_logprobs_text_chosen_token.append(
seq.output_logprobs[-1][new_token].decoded_token)
sequential_logprobs_text_other_token.append(
seq.output_logprobs[-1][new_token + 1].decoded_token)
sequential_result = seq.output_text
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)
if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
assert sequential_result == complete_sequence
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
seq = create_sequence(complete_sequence_token_ids)
seq_group = SequenceGroup(request_id="1",
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group,
dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.tokenizer
text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):]
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that the first logprob is None.
assert text == "".join([
logprobs[token_id].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
assert text != "".join([
logprobs[token_id + 1].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
...@@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow ...@@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall) ToolCall)
from vllm.entrypoints.openai.tool_parsers import JambaToolParser from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "ai21labs/Jamba-tiny-dev" MODEL = "ai21labs/Jamba-tiny-dev"
......
...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ToolCall) ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser) Qwen3CoderToolParser)
from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
......
...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall, DeltaMessage, FunctionCall,
ToolCall) ToolCall)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available # Use a common model that is likely to be available
......
...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall, DeltaMessage, FunctionCall,
ToolCall) ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available # Use a common model that is likely to be available
......
...@@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, ...@@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS, STOP_STRINGS,
DummyOutputProcessorTestVectors, DummyOutputProcessorTestVectors,
MockEngineCore) MockEngineCore)
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import (OutputProcessor, from vllm.v1.engine.output_processor import (OutputProcessor,
......
...@@ -15,10 +15,10 @@ from vllm.config import VllmConfig ...@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils import make_async from vllm.utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook ...@@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray) ray)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method, from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async) get_ip, get_open_port, make_async)
from vllm.v1.outputs import SamplerOutput
if ray is not None: if ray is not None:
from ray.actor import ActorHandle from ray.actor import ActorHandle
......
...@@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, ...@@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, embeds_inputs, build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext, from .registry import InputContext, InputProcessingContext
InputRegistry)
INPUT_REGISTRY = InputRegistry()
"""
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
__all__ = [ __all__ = [
"DataPrompt", "DataPrompt",
...@@ -36,9 +28,6 @@ __all__ = [ ...@@ -36,9 +28,6 @@ __all__ = [
"build_explicit_enc_dec_prompt", "build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list", "to_enc_dec_tuple_list",
"zip_enc_dec_prompts", "zip_enc_dec_prompts",
"INPUT_REGISTRY",
"DummyData",
"InputContext", "InputContext",
"InputProcessingContext", "InputProcessingContext",
"InputRegistry",
] ]
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union from typing import TYPE_CHECKING, Any, Union
import torch import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
...@@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves ...@@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
MultiModalRegistry)
from vllm.sequence import SequenceData
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
else: else:
ModelConfig = Any ModelConfig = Any
MultiModalDataDict = Any
MultiModalPlaceholderDict = Any
MultiModalRegistry = Any
SequenceData = Any
AnyTokenizer = Any AnyTokenizer = Any
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -191,61 +184,3 @@ class InputProcessingContext(InputContext): ...@@ -191,61 +184,3 @@ class InputProcessingContext(InputContext):
f"on data={data} with kwargs={allowed_kwargs}") f"on data={data} with kwargs={allowed_kwargs}")
raise ValueError(msg) from exc raise ValueError(msg) from exc
class DummyData(NamedTuple):
"""
Dummy data used for profiling.
Note: This is only used in V0.
"""
seq_data: SequenceData
multi_modal_data: Optional[MultiModalDataDict] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
class InputRegistry:
"""
Note: This is only used in V0.
"""
def dummy_data_for_profiling(
self,
model_config: ModelConfig,
seq_len: int,
mm_registry: MultiModalRegistry,
is_encoder_data: bool = False,
) -> DummyData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
multi_modal_data=dec_data.multi_modal_data.get_data(),
multi_modal_placeholders=dec_data.multi_modal_placeholders,
)
...@@ -3,13 +3,11 @@ ...@@ -3,13 +3,11 @@
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter) PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import SamplingMetadata
SamplingMetadataCache)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"SamplingMetadata", "SamplingMetadata",
"SamplingMetadataCache",
"set_random_seed", "set_random_seed",
"BasevLLMParameter", "BasevLLMParameter",
"PackedvLLMParameter", "PackedvLLMParameter",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that compute logits from hidden_stats.""" """A layer that compute logits from hidden_stats."""
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional
import torch import torch
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather, from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform from vllm.platforms import current_platform
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)
@CustomOp.register("logits_processor") @CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp): class LogitsProcessor(CustomOp):
...@@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp): ...@@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None, sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
prune_hidden_states: bool = True,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
logits = hidden_states logits = hidden_states
else: else:
if sampling_metadata is not None and prune_hidden_states:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias) logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if logits is not None: if logits is not None:
...@@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp): ...@@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp):
if self.scale != 1.0: if self.scale != 1.0:
logits *= self.scale logits *= self.scale
# Apply logits processors (if any).
if sampling_metadata is not None and \
sampling_metadata.seq_groups is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
return logits return logits
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
...@@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp): ...@@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp):
s += f", org_vocab_size={self.org_vocab_size}" s += f", org_vocab_size={self.org_vocab_size}"
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
return s return s
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
# (warmup, profile_run) we might not have selected_token_indices,
# so we skip pruning.
if sampling_metadata.selected_token_indices is not None:
return hidden_states.index_select(
0, sampling_metadata.selected_token_indices)
else:
return hidden_states
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
if _logits_processor_threadpool is not None:
logits_row_ids_and_logits_row_futures.append(
(logits_row_idx,
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits
def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment