Unverified Commit 26e673fe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent 65a5910c
......@@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
initialize_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import set_default_torch_num_threads
......
......@@ -7,8 +7,8 @@ from typing import Optional
import pytest
from transformers import AutoModelForSpeechSeq2Seq
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs
from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput,
VllmRunner)
......
......@@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode, rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
PromptImageInput, VllmRunner)
......
......@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor
from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.logprobs import Logprob, SampleLogprobs
from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs
from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close
......
......@@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
GenerationConfig, GenerationMixin)
from transformers.video_utils import VideoMetadata
from vllm.sequence import SampleLogprobs
from vllm.logprobs import SampleLogprobs
from vllm.utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
......
......@@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import RunnerOption
from vllm.sequence import SampleLogprobs
from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
......
......@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from .registry import HF_EXAMPLE_MODELS
......
......@@ -8,10 +8,7 @@ import pytest
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
......@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast):
assert decoded_text == ''
assert out_ids == [len(tokenizer)]
@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
tokenizer = get_tokenizer(
tokenizer_name,
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
trust_remote_code=False,
revision=None,
)
return Detokenizer(tokenizer)
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]:
return tokenizer(complete_sequence, add_special_tokens=False).input_ids
def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or []
return Sequence(
seq_id=0,
inputs=token_inputs(prompt_token_ids),
block_size=16,
)
def create_dummy_logprobs(
complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
} for token_id in complete_sequence_token_ids]
def create_dummy_prompt_logprobs(
complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
# logprob for the first prompt token is None.
logprobs: list[Optional[dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
logprobs=2)
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token: list[str] = []
sequential_logprobs_text_other_token: list[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
detokenizer.decode_sequence_inplace(seq, sampling_params)
sequential_logprobs_text_chosen_token.append(
seq.output_logprobs[-1][new_token].decoded_token)
sequential_logprobs_text_other_token.append(
seq.output_logprobs[-1][new_token + 1].decoded_token)
sequential_result = seq.output_text
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)
if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
assert sequential_result == complete_sequence
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
seq = create_sequence(complete_sequence_token_ids)
seq_group = SequenceGroup(request_id="1",
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group,
dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.tokenizer
text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):]
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that the first logprob is None.
assert text == "".join([
logprobs[token_id].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
assert text != "".join([
logprobs[token_id + 1].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
......@@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "ai21labs/Jamba-tiny-dev"
......
......@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser)
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
......
......@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
......
......@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
......
......@@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore)
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import (OutputProcessor,
......
......@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
......
......@@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
from vllm.v1.outputs import SamplerOutput
if ray is not None:
from ray.actor import ActorHandle
......
......@@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
INPUT_REGISTRY = InputRegistry()
"""
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
from .registry import InputContext, InputProcessingContext
__all__ = [
"DataPrompt",
......@@ -36,9 +28,6 @@ __all__ = [
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"DummyData",
"InputContext",
"InputProcessingContext",
"InputRegistry",
]
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Any, Union
import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
......@@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
MultiModalRegistry)
from vllm.sequence import SequenceData
from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
ModelConfig = Any
MultiModalDataDict = Any
MultiModalPlaceholderDict = Any
MultiModalRegistry = Any
SequenceData = Any
AnyTokenizer = Any
_T = TypeVar("_T")
......@@ -191,61 +184,3 @@ class InputProcessingContext(InputContext):
f"on data={data} with kwargs={allowed_kwargs}")
raise ValueError(msg) from exc
class DummyData(NamedTuple):
"""
Dummy data used for profiling.
Note: This is only used in V0.
"""
seq_data: SequenceData
multi_modal_data: Optional[MultiModalDataDict] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
class InputRegistry:
"""
Note: This is only used in V0.
"""
def dummy_data_for_profiling(
self,
model_config: ModelConfig,
seq_len: int,
mm_registry: MultiModalRegistry,
is_encoder_data: bool = False,
) -> DummyData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
multi_modal_data=dec_data.multi_modal_data.get_data(),
multi_modal_placeholders=dec_data.multi_modal_placeholders,
)
......@@ -3,13 +3,11 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingMetadataCache)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"SamplingMetadata",
"SamplingMetadataCache",
"set_random_seed",
"BasevLLMParameter",
"PackedvLLMParameter",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that compute logits from hidden_stats."""
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp
......@@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
......@@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
prune_hidden_states: bool = True,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
if sampling_metadata is not None and prune_hidden_states:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if logits is not None:
......@@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp):
if self.scale != 1.0:
logits *= self.scale
# Apply logits processors (if any).
if sampling_metadata is not None and \
sampling_metadata.seq_groups is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
return logits
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
......@@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp):
s += f", org_vocab_size={self.org_vocab_size}"
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
return s
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
# (warmup, profile_run) we might not have selected_token_indices,
# so we skip pruning.
if sampling_metadata.selected_token_indices is not None:
return hidden_states.index_select(
0, sampling_metadata.selected_token_indices)
else:
return hidden_states
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
if _logits_processor_threadpool is not None:
logits_row_ids_and_logits_row_futures.append(
(logits_row_idx,
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits
def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment