Unverified Commit 26e673fe authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent 65a5910c
...@@ -2,18 +2,15 @@ ...@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -105,8 +102,10 @@ class Medusa(nn.Module): ...@@ -105,8 +102,10 @@ class Medusa(nn.Module):
return [block(hidden_states) for block in self.blocks] return [block(hidden_states) for block in self.blocks]
def compute_logits( def compute_logits(
self, hidden_states: list[torch.Tensor], self,
sampling_metadata: SamplingMetadata) -> list[torch.Tensor]: hidden_states: list[torch.Tensor],
sampling_metadata,
) -> list[torch.Tensor]:
logits_lst: list[torch.Tensor] = [] logits_lst: list[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads): for hs, lm_head in zip(hidden_states, self.lm_heads):
...@@ -130,57 +129,6 @@ class Medusa(nn.Module): ...@@ -130,57 +129,6 @@ class Medusa(nn.Module):
return logits_lst return logits_lst
def sample(
self,
logits: list[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
token_id_list = []
token_prob_list = []
token_logprob_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(token_ids[:, seq_group.sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])
outputs: list[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_prob_list[idx].squeeze(1),
logprobs=token_logprob_list[idx].squeeze(1),
sampled_token_ids=token_id_list[idx].squeeze(1),
))
return outputs
def generate_proposals(
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
# During preemption, we may receive an empty tensor (batch_size=0)
if previous_hidden_states.size(0) == 0:
# Return None to signal the Top1Proposer that no proposals
# were generated for this batch, allowing it to handle this
# special case appropriately
return None
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
),
sampling_metadata=sampling_metadata,
)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
......
...@@ -8,9 +8,7 @@ import torch ...@@ -8,9 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module): ...@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module):
self.config = config self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0) config.vocab_size, 1.0)
self.sampler = get_sampler()
def generate_proposals( # NOTE(woosuk): This method is commented out because it is old code
self, # using V0. We should either port it to V1 or remove it.
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
if self.scale_input: # def generate_proposals(
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 # self,
# input_ids: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# num_predict_tokens: int,
# sampling_metadata: SamplingMetadata,
# ) -> list[SamplerOutput]:
# if num_predict_tokens > self.max_speculative_tokens:
# raise ValueError(f"Max speculative tokens for model is "
# f"{self.max_speculative_tokens}, but "
# f"{num_predict_tokens} were requested")
# # b x 1 x d
# previous_hidden_states = previous_hidden_states.unsqueeze(1)
# if self.scale_input:
# previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# b x 1 # # b x 1
last_tokens = input_ids.unsqueeze(1) # last_tokens = input_ids.unsqueeze(1)
next_tokens = [] # next_tokens = []
for head_index in range(num_predict_tokens): # for head_index in range(num_predict_tokens):
# Project and predict # # Project and predict
z = self.emb[head_index](last_tokens) # b k d # z = self.emb[head_index](last_tokens) # b k d
states = self.proj[head_index](previous_hidden_states) # states = self.proj[head_index](previous_hidden_states)
# Weighted add of state_weight*state and emb_weight*z # # Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator # # Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues # # state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight) # states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d # states = self.activation(self.ln[head_index](states)) # b k d
previous_hidden_states = states # previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head # # TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1) # states = states.flatten(0, 1)
logits = self.logits_processor(self.head[head_index], states, # logits = self.logits_processor(self.head[head_index], states,
sampling_metadata) # sampling_metadata)
output = self.sampler(logits, sampling_metadata) # output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids # last_tokens = output.sampled_token_ids
next_tokens.append(output) # next_tokens.append(output)
return next_tokens # return next_tokens
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
......
...@@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
# If the shape is the same, it means that we have already
# prune hidden states manually.
prune_hidden_states = hidden_states.size(
0) != sampling_metadata.selected_token_indices.size(0)
processed_logits = self.logits_processor( processed_logits = self.logits_processor(
self.lm_head, self.lm_head,
hidden_states, hidden_states,
sampling_metadata, sampling_metadata,
self.embedding_bias, self.embedding_bias,
prune_hidden_states=prune_hidden_states) )
return processed_logits return processed_logits
def load_weights( def load_weights(
......
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.logprobs import Logprob
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
SequenceGroup)
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[
int, Logprob]]],
position_offset: int) -> None:
"""Decodes the logprobs for the prompt of a sequence group.
Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).
Returns:
The prompt logprobs with the decoded tokens.
"""
prms = seq_group.sampling_params
assert prms is not None
# We can pick any sequence for the prompt.
seq = seq_group.get_seqs()[0]
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: list[str] = []
prev_tokens = None
for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
prompt_logprobs):
# Absolute token position equals the index in the logprobs
# list plus the offset of the entire logprobs list relative
# to the start of the sequence.
token_position = token_position_in_logprob + position_offset
if not prompt_logprobs_for_token:
continue
for token_id, sample_logprob in prompt_logprobs_for_token.items():
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
prompt_token_ids_with_token = (
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
# Use the offsets & prev tokens corresponding to
# real tokens to ensure detokenization is consistent
# actual with prompt.
if token_id == all_token_ids[token_position]:
next_iter_prefix_offset = new_prefix_offset
next_iter_read_offset = new_read_offset
next_iter_tokens = new_tokens
# Advance to the next token position.
prefix_offset = next_iter_prefix_offset
read_offset = next_iter_read_offset
if prev_tokens is None:
prev_tokens = next_iter_tokens.copy()
else:
prev_tokens.extend(next_iter_tokens)
def decode_sequence_inplace(self, seq: Sequence,
prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
# Decode logprobs
logprobs = seq.output_logprobs[-1]
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
# If the token was generated this iteration,
# use the provided text.
if token_id == token_id_generated_this_iteration:
sample_logprob.decoded_token = new_decoded_token_text
continue
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
return len(new_decoded_token_text)
...@@ -11,12 +11,12 @@ import torch.nn as nn ...@@ -11,12 +11,12 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method, resolve_obj_by_qualname, run_method,
update_environment_variables, update_environment_variables,
warn_for_unimplemented_methods) warn_for_unimplemented_methods)
from vllm.v1.outputs import SamplerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment