Unverified Commit 26e673fe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent 65a5910c
......@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from .utils import maybe_prefix
......@@ -105,8 +102,10 @@ class Medusa(nn.Module):
return [block(hidden_states) for block in self.blocks]
def compute_logits(
self, hidden_states: list[torch.Tensor],
sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
self,
hidden_states: list[torch.Tensor],
sampling_metadata,
) -> list[torch.Tensor]:
logits_lst: list[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
......@@ -130,57 +129,6 @@ class Medusa(nn.Module):
return logits_lst
def sample(
    self,
    logits: list[torch.Tensor],
    sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
    """Greedily pick one token per entry of ``logits`` for each sequence group.

    The tensors in ``logits`` are stacked along a new leading dimension and
    cast to float. For every sequence group in ``sampling_metadata`` the
    entries at ``seq_group.sample_indices`` are selected along dimension 1
    and that dimension is squeezed away.

    Args:
        logits: Logits tensors to stack; presumably one per Medusa head,
            all of the same shape -- TODO confirm against the caller.
        sampling_metadata: Supplies ``seq_groups``; each group's
            ``sample_indices`` chooses which entries belong to it.

    Returns:
        One ``SamplerOutput`` per sequence group, in ``seq_groups`` order,
        carrying the selected probabilities, log-probabilities and argmax
        token ids. ``outputs`` is always ``None``.
    """
    stacked = torch.stack(logits, dim=0).float()
    log_probs = torch.log_softmax(stacked, dim=-1)
    # support only top-1 for now
    greedy_ids = stacked.argmax(-1)
    probs = torch.softmax(stacked, dim=-1)

    results: list[Optional[SamplerOutput]] = []
    for group in sampling_metadata.seq_groups:
        picked = group.sample_indices
        results.append(
            SamplerOutput(
                outputs=None,
                sampled_token_probs=probs[:, picked].squeeze(1),
                logprobs=log_probs[:, picked].squeeze(1),
                sampled_token_ids=greedy_ids[:, picked].squeeze(1),
            ))
    return results
def generate_proposals(
    self,
    previous_hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
    """Run forward, compute_logits and sample on the given hidden states.

    Args:
        previous_hidden_states: Hidden states to feed to ``self.forward``;
            its first dimension is treated as the batch size.
        sampling_metadata: Passed through unchanged to ``compute_logits``
            and ``sample``.

    Returns:
        The list produced by ``self.sample``, or ``None`` when the batch is
        empty.
    """
    # During preemption, we may receive an empty tensor (batch_size=0).
    # Returning None signals the Top1Proposer that no proposals were
    # generated for this batch, so it can handle that special case itself.
    batch_is_empty = previous_hidden_states.size(0) == 0
    if batch_is_empty:
        return None

    head_states = self.forward(previous_hidden_states)
    head_logits = self.compute_logits(
        hidden_states=head_states,
        sampling_metadata=sampling_metadata,
    )
    return self.sample(
        logits=head_logits,
        sampling_metadata=sampling_metadata,
    )
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
......
......@@ -8,9 +8,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module):
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0)
self.sampler = get_sampler()
def generate_proposals(
self,
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
# NOTE(woosuk): This method is commented out because it is old code
# using V0. We should either port it to V1 or remove it.
if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# def generate_proposals(
# self,
# input_ids: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# num_predict_tokens: int,
# sampling_metadata: SamplingMetadata,
# ) -> list[SamplerOutput]:
# if num_predict_tokens > self.max_speculative_tokens:
# raise ValueError(f"Max speculative tokens for model is "
# f"{self.max_speculative_tokens}, but "
# f"{num_predict_tokens} were requested")
# # b x 1 x d
# previous_hidden_states = previous_hidden_states.unsqueeze(1)
# if self.scale_input:
# previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# b x 1
last_tokens = input_ids.unsqueeze(1)
# # b x 1
# last_tokens = input_ids.unsqueeze(1)
next_tokens = []
# next_tokens = []
for head_index in range(num_predict_tokens):
# for head_index in range(num_predict_tokens):
# Project and predict
z = self.emb[head_index](last_tokens) # b k d
states = self.proj[head_index](previous_hidden_states)
# # Project and predict
# z = self.emb[head_index](last_tokens) # b k d
# states = self.proj[head_index](previous_hidden_states)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight)
# # Weighted add of state_weight*state and emb_weight*z
# # Let subsequent LN take care of denominator
# # state_weight is close to 1, so shouldn't be any precision issues
# states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d
previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
# states = self.activation(self.ln[head_index](states)) # b k d
# previous_hidden_states = states
# # TODO: not yet supporting top_k_tokens_per_head
# states = states.flatten(0, 1)
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
# logits = self.logits_processor(self.head[head_index], states,
# sampling_metadata)
output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)
# output = self.sampler(logits, sampling_metadata)
# last_tokens = output.sampled_token_ids
# next_tokens.append(output)
return next_tokens
# return next_tokens
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
......
......@@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# If the shape is the same, it means that we have already
# pruned the hidden states manually.
prune_hidden_states = hidden_states.size(
0) != sampling_metadata.selected_token_indices.size(0)
processed_logits = self.logits_processor(
self.lm_head,
hidden_states,
sampling_metadata,
self.embedding_bias,
prune_hidden_states=prune_hidden_states)
)
return processed_logits
def load_weights(
......
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.logprobs import Logprob
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
SequenceGroup)
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer: AnyTokenizer):
        """Store the tokenizer used by every decode call on this object."""
        self.tokenizer = tokenizer

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: list[Optional[dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        The ``decoded_token`` field of each ``Logprob`` is filled in place
        when it is still ``None`` and its token id is not
        ``VLLM_INVALID_TOKEN_ID``.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).

        Returns:
            None. The entries of ``prompt_logprobs`` are updated in place.
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        # Detokenizer state used for the current position, and the state to
        # switch to once every candidate at this position has been decoded.
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens: list[str] = []
        prev_tokens = None

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset
            # Skip positions that carry no logprobs (None or empty dict).
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the candidate as if it followed the prompt
                    # tokens that precede this position.
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=self.tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to the
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                # Copy so that later extend() calls do not mutate the list
                # still referenced by next_iter_tokens.
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
        and ``seq.output_text``, and fills ``decoded_token`` on the entries
        of the last element of ``seq.output_logprobs``.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=self.tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        # Decode the newly generated token; the results are committed to
        # `seq` only after the logprobs below have been decoded against the
        # pre-update state.
        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=self.tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the alternative token as if it had been
                    # generated in place of the actual one.
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=self.tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        # Commit the detokenizer state for the generated token.
        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)
......@@ -11,12 +11,12 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
update_environment_variables,
warn_for_unimplemented_methods)
from vllm.v1.outputs import SamplerOutput
logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment