Commit cf4be8ff authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

[feat]支持宽松mtp

See merge request dcutoolkit/deeplearing/vllm!414
parents 4a4fb3de aec90b84
......@@ -291,6 +291,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False
def get_default_cache_root():
......@@ -1836,6 +1837,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# vllm will use optimized reject sample
"VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'True').lower() in
("true", "1")),
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from collections.abc import Sequence
from dataclasses import replace
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128
class OptRejectionSampler(nn.Module):
    """Relaxed rejection sampler for speculative decoding.

    The overall structure is derived from the rejection-sampling scheme of
    https://arxiv.org/abs/2211.17192, but this variant intentionally does
    NOT follow that algorithm strictly:

    * Target tokens for every position (draft positions and bonus
      positions) are obtained from a single call to the wrapped sampler
      with the greedy flags forced on.
    * A draft token is accepted when it equals the target token at that
      position, or when ``target_prob / draft_prob`` reaches the threshold
      drawn in ``rejection_sample``.
    * On the first rejection the target token is emitted in place of the
      draft token, and the remaining slots of that request keep
      ``PLACEHOLDER_TOKEN_ID``.

    Terminology used in the implementation:

    accepted tokens:
        Draft tokens that passed the acceptance test above.
    bonus tokens:
        If all proposed tokens of a request are accepted, the bonus token
        is appended to the end of its output row. It is taken from the
        wrapped sampler's output at the request's bonus position.
    output tokens:
        The tokens finally produced for a request: the accepted tokens
        followed by either the target token that replaced the first
        rejected draft token, or the bonus token.
    """

    def __init__(self, sampler: Sampler):
        super().__init__()
        self.sampler = sampler
        logprobs_mode = self.sampler.logprobs_mode
        # Cache the two facets of the logprobs mode used on the hot path.
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # 3-D when provided (asserted in rejection_sample)
        draft_probs: torch.Tensor | None,
        # [num_tokens + batch_size, vocab_size]
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Run the relaxed rejection sampler for one decoding step.

        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens, or None if
                probabilities are not provided. When given it must be a
                contiguous 3-D tensor (``rejection_sample`` asserts this).
                NOTE(review): presumably
                [batch_size, num_spec_tokens, vocab_size] -- confirm
                against the proposer.
            logits (torch.Tensor):
                Target model's logits.
                Shape is [num_tokens + batch_size, vocab_size]. Here,
                rows from different requests are flattened into a single
                tensor because this is the shape of the output logits.
                NOTE(review): the wrapped sampler may update `logits` in
                place -- confirm before relying on their value afterwards.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling. It is not
                modified by this method.
        Returns:
            SamplerOutput:
                Contains the final output token IDs and their logprobs if
                requested.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices
        assert logits is not None

        # The relaxed scheme needs the sampler's greedy path for every
        # row. Put the override on a copy: the caller's SamplingMetadata
        # outlives this call, and flipping its flags in place would leak
        # greedy sampling into whatever uses the same object afterwards.
        greedy_sampling_metadata = replace(
            sampling_metadata,
            all_greedy=True,
            all_random=False,
            max_num_logprobs=-1,
        )
        # One sampler call over all rows yields both the target token of
        # each draft position and the bonus token of each request.
        sampler_output = self.sampler(
            logits=logits,
            sampling_metadata=greedy_sampling_metadata,
            predict_bonus_token=True,
            # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override="processed_logits"
            if self.is_processed_logprobs_mode
            else "raw_logits",
        )

        target_logits = logits[target_logits_indices]
        target_tokens = sampler_output.sampled_token_ids[target_logits_indices]
        bonus_token_ids = sampler_output.sampled_token_ids[bonus_logits_indices]

        # Compute probability distribution from target logits.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            target_tokens,
            bonus_token_ids,
            sampling_metadata,
        )

        logprobs_tensors = None
        if sampling_metadata.max_num_logprobs is not None:
            logprobs_tensors = self._get_logprobs_tensors(
                sampling_metadata.max_num_logprobs,
                metadata,
                sampler_output.logprobs_tensors.logprobs,
                output_token_ids,
            )

        return SamplerOutput(
            sampled_token_ids=output_token_ids,
            logprobs_tensors=logprobs_tensors,
        )

    def _get_logprobs_tensors(
        self,
        max_num_logprobs: int,
        metadata: SpecDecodeMetadata,
        logits: torch.Tensor,
        sampled_token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """Gather logprobs for every slot of ``sampled_token_ids``.

        Args:
            max_num_logprobs: Forwarded to ``sampler.gather_logprobs``.
            metadata: Supplies ``cu_num_sampled_tokens``.
            logits: Per-row values returned by the wrapped sampler; cast to
                float32 before use.
            sampled_token_ids: Output of ``rejection_sample``; rejected
                slots hold ``PLACEHOLDER_TOKEN_ID``.
        """
        # Shift the inclusive cumulative counts right by one so that entry i
        # is the first logits row belonging to request i.
        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]

        final_logits = logits.to(torch.float32)

        # NOTE: To avoid cpu-gpu synchronization, we now simply compute
        # indices for all draft tokens, including the rejected ones. The
        # rejected tokens will be filtered out in the `parse_output`.
        logit_start_indices = cu_num_sampled_tokens
        offsets = torch.arange(
            sampled_token_ids.shape[-1],
            device=logit_start_indices.device,
            dtype=logit_start_indices.dtype,
        )
        accepted_logit_indices = (
            logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
        ).flatten()
        # Slots past a request's real length would index out of range.
        accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)

        accepted_tokens = sampled_token_ids.clone().flatten()
        # we replace rejected token ids with 0 to avoid gather_logprobs error
        accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0

        # Compute logprobs for accepted tokens.
        accepted_logits = final_logits[accepted_logit_indices]
        accepted_logprobs = (
            accepted_logits
            if self.is_logits_logprobs_mode
            else self.sampler.compute_logprobs(accepted_logits)
        )
        return self.sampler.gather_logprobs(
            accepted_logprobs,
            max_num_logprobs,
            accepted_tokens.to(torch.int64),
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
        discard_req_indices: Sequence[int] = (),
        logprobs_tensors: LogprobsTensors | None = None,
    ) -> tuple[list[list[int]], LogprobsLists | None]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.
            discard_req_indices: Optional row indices to discard tokens in.
            logprobs_tensors: Optional logprobs tensors to filter.
        Returns:
            A tuple of (per-request lists of valid token IDs, filtered
            logprobs or None when `logprobs_tensors` is None).
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
            output_token_ids_np < vocab_size
        )

        # Logprobs are filtered before rows are discarded, so discarded
        # requests still contribute their logprobs.
        output_logprobs = None
        if logprobs_tensors is not None:
            cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
            filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
            output_logprobs = filtered_tensors.tolists(cu_num_tokens)

        if len(discard_req_indices) > 0:
            valid_mask[discard_req_indices] = False
        outputs = [
            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
        ]
        return outputs, output_logprobs

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
    ) -> torch.Tensor:
        """Apply penalties, the allowed-token mask and bad-word exclusion.

        Returns the processed logits; `logits` may also be updated in
        place (the allowed-token mask is applied with `masked_fill_`).
        """
        has_penalties = not sampling_metadata.no_penalties
        any_penalties_or_bad_words = (
            sampling_metadata.bad_words_token_ids or has_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if any_penalties_or_bad_words:
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Calculate indices of target logits.
        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
            num_requests = len(sampling_metadata.output_token_ids)
            num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
            original_indices = torch.arange(num_requests, device="cpu")
            # Request index repeated once per draft token of that request.
            repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
            repeat_indices = repeat_indices_cpu.to(
                device=logits.device, non_blocking=True
            )
            logits = self.apply_penalties(
                logits, sampling_metadata, metadata, repeat_indices, output_token_ids
            )

            # Apply allowed token ids.
            if sampling_metadata.allowed_token_ids_mask is not None:
                token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
                logits.masked_fill_(token_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
            apply_bad_words_with_drafts(
                logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
            )

        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
        repeat_indices: torch.Tensor,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties to `logits`.

        The per-request penalty tensors are expanded with `repeat_indices`
        before being handed to `apply_all_penalties`. Returns `logits`
        unchanged when `sampling_metadata.no_penalties` is set.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None

        prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
        presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
        frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
        repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]

        logits = apply_all_penalties(
            logits,
            prompt_token_ids,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            output_token_ids,
        )
        return logits

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Expand per-request outputs into one history per draft position.

        For a request with output `out` and draft tokens `s0..s(k-1)` the
        result holds `out`, `out + [s0]`, ..., `out + [s0..s(k-2)]`.
        Requests with no draft tokens contribute nothing. When
        `spec_token_ids` is None the input is returned as is.
        """
        if spec_token_ids is None:
            return output_token_ids

        result = []
        for out, spec in zip(output_token_ids, spec_token_ids):
            if len(spec) == 0:
                continue
            result.append(out)
            # The last draft token is never part of a history prefix.
            for token in spec[:-1]:
                result.append([*result[-1], token])
        return result
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # 3-D when provided (asserted below), or None
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # one target token id per draft position
    target_tokens,
    # one bonus token id per request
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Launch the relaxed rejection-sampling kernel for one step.

    Allocates a [batch_size, max_spec_len + 1] int32 buffer filled with
    `PLACEHOLDER_TOKEN_ID`, draws one acceptance threshold per draft token
    and runs `rejection_random_sample_kernel` with one program per request.

    `sampling_metadata` is accepted for signature compatibility and is not
    read here.

    Returns:
        The output token id buffer described above.
    """
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 3
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    num_requests = len(num_draft_tokens)
    total_draft_tokens = draft_token_ids.shape[0]
    vocab = target_probs.shape[-1]
    dev = target_probs.device

    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (total_draft_tokens, vocab)

    # Every slot starts out as "rejected"; the kernel overwrites the slots
    # it actually produces a token for.
    sampled = torch.full(
        (num_requests, max_spec_len + 1),
        PLACEHOLDER_TOKEN_ID,
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=dev,
    )

    # One acceptance threshold per draft token: a uniform draw from [0, 1)
    # scaled by 0.1 and shifted by 0.1.
    thresholds = torch.rand(
        (total_draft_tokens, ),
        dtype=torch.float32,
        device=dev,
    )
    thresholds = thresholds * 0.1 + 0.1

    # One kernel program per request.
    launch_grid = (num_requests, )
    rejection_random_sample_kernel[launch_grid](
        sampled,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        thresholds,
        max_spec_len,
        vocab,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return sampled
# Relaxed acceptance kernel: one program per request walks that request's
# draft tokens in order. A draft token is kept when it equals the target
# token at the same position, or when target_prob / draft_prob reaches the
# per-token threshold; otherwise the target token is written instead and
# the rest of the request is left untouched. If nothing was rejected, the
# request's bonus token is written after the last draft slot.
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # read as rows of vocab_size, one per draft token; or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # one target token id per draft token
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    # cu_num_draft_tokens holds inclusive cumulative counts, so this
    # request's tokens live in [start_idx, end_idx).
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        # Once a token has been rejected, the later slots are skipped and
        # keep whatever the caller pre-filled the output buffer with.
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            # Negative ids are mapped to 0 so the probability loads below
            # stay in bounds.
            if draft_token_id < 0:
                draft_token_id = 0
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                # NOTE(review): this addressing treats draft_probs as one
                # row per flattened draft-token position (start_idx + pos).
                # Confirm that the caller's tensor is laid out in exactly
                # that order when requests have differing draft lengths.
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            draft_token_id = draft_token_id.to(tl.int64)
            target_token_id = tl.load(target_token_ids_ptr + (start_idx + pos))
            target_token_id = target_token_id.to(tl.int64)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if (draft_token_id == target_token_id) or (target_prob / draft_prob >= uniform_prob and draft_prob > 0):
                token_id = draft_token_id
            else:
                # Rejected: emit the target token instead of the draft one.
                rejected = True
                token_id = target_token_id
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
......@@ -8,6 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import (
CUDAGraphMode,
VllmConfig,
......@@ -397,9 +398,16 @@ class SpecDecodeBaseProposer:
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1)
if envs.VLLM_REJECT_SAMPLE_OPT:
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, logits.shape[-1])
return draft_token_ids.view(-1, 1)
if self.uses_mrope:
......@@ -472,6 +480,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs_list = [draft_prob]
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
......@@ -598,8 +609,17 @@ class SpecDecodeBaseProposer:
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs
return draft_token_ids
def set_inputs_first_pass(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
......@@ -22,6 +23,8 @@ class SpecDecodeMetadata:
bonus_logits_indices: torch.Tensor
# [num_tokens + batch_size]
logits_indices: torch.Tensor
# [batch_size]
spec_decode_ids: Optional[list[str]] = None
def __post_init__(self):
self.max_spec_len = max(self.num_draft_tokens)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import async_tensor_h2d
@triton.jit
......@@ -107,3 +111,74 @@ def eagle_prepare_next_token_padded_kernel(
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probs corresponding to in-progress sequences.

    Row ``i`` of ``draft_probs`` belongs to request ``_req_ids[i]``.
    ``req_id_to_count`` records, per request, the value of ``count`` at the
    time its probabilities were last written; ``prune`` uses it to drop
    entries that have not been refreshed for ``prune_threshould`` updates.
    """

    # Per-instance state is created in __init__. Only annotations live at
    # class level so that instances can never share a mutable list/dict.
    # spec tokens probs.
    draft_probs: torch.Tensor
    # The request id list.
    _req_ids: list[str]
    req_id_to_count: dict[str, int]

    # Number of update() calls seen so far.
    count = 0
    # Staleness horizon, in update() calls. The attribute name (including
    # its spelling) is kept for backward compatibility.
    prune_threshould = 100

    def __init__(self, draft_probs, req_ids):
        """Store ``draft_probs`` whose rows correspond to ``req_ids``."""
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Copy so later bookkeeping never aliases the caller's list.
        self._req_ids = list(req_ids)
        self.count = 0
        self.req_id_to_count = {req_id: self.count for req_id in self._req_ids}

    def _select_rows(self, rows: list[int]) -> torch.Tensor:
        """Return the rows of ``draft_probs`` at ``rows``, in that order."""
        index_tensor = async_tensor_h2d(
            rows,
            dtype=torch.int32,
            target_device=self.draft_probs.device,
            pin_memory=True)
        return self.draft_probs[index_tensor]

    def _split_kept(self, dropped: set) -> tuple[list[int], list[str]]:
        """Return (row indices, request ids) of entries not in ``dropped``."""
        kept_rows: list[int] = []
        kept_ids: list[str] = []
        for row, req_id in enumerate(self._req_ids):
            if req_id not in dropped:
                kept_rows.append(row)
                kept_ids.append(req_id)
        return kept_rows, kept_ids

    def update(self,
               draft_probs: torch.Tensor,
               tmp_req_ids: list[str]):
        """Replace or add the probabilities of ``tmp_req_ids``.

        Existing rows of requests that appear in ``tmp_req_ids`` are
        dropped, the remaining rows keep their relative order, and the new
        rows are appended in the order of ``tmp_req_ids``.
        """
        self.count += 1
        tmp_req_ids = list(tmp_req_ids)
        kept_rows, kept_ids = self._split_kept(set(tmp_req_ids))

        # Skip the gather (and its host-to-device copy) when nothing is
        # being replaced.
        if len(kept_rows) != len(self._req_ids):
            kept_probs = self._select_rows(kept_rows)
        else:
            kept_probs = self.draft_probs

        self.draft_probs = torch.cat([kept_probs, draft_probs])
        self._req_ids = kept_ids + tmp_req_ids
        for req_id in tmp_req_ids:
            self.req_id_to_count[req_id] = self.count
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the rows of ``req_ids`` and, periodically, of stale requests.

        Whenever ``count`` is a multiple of ``prune_threshould``, requests
        whose last update is at least ``prune_threshould`` updates old are
        dropped as well. ``req_ids`` itself is not modified.
        """
        dropped = set(req_ids)
        if self.count % self.prune_threshould == 0:
            dropped.update(
                req_id
                for req_id, last_count in self.req_id_to_count.items()
                if self.count - last_count >= self.prune_threshould)

        self.req_id_to_count = {
            req_id: last_count
            for req_id, last_count in self.req_id_to_count.items()
            if req_id not in dropped
        }

        kept_rows, kept_ids = self._split_kept(dropped)
        if len(kept_ids) != len(self._req_ids):
            # Batch contents changed - prune removed sequences.
            self.draft_probs = self._select_rows(kept_rows)
            self._req_ids = kept_ids

    def get_probs(self, req_ids: list[str]):
        """Return the cached rows for ``req_ids``, in the given order.

        Raises:
            ValueError: if a request id has no cached probabilities.
        """
        # First occurrence wins, matching ``list.index`` semantics.
        row_of: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            row_of.setdefault(req_id, row)
        try:
            rows = [row_of[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(
                f"no cached draft probs for request {err.args[0]!r}"
            ) from None
        return self._select_rows(rows)
......@@ -12,7 +12,7 @@ from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from functools import reduce
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast, Optional
import numpy as np
import torch
......@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
......@@ -181,6 +182,7 @@ from .utils import (
bind_kv_cache,
sanity_check_mm_encoder_outputs,
)
from vllm.v1.spec_decode.utils import DraftProbs
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
......@@ -470,7 +472,11 @@ class GPUModelRunner(
"Unknown speculative decoding method: "
f"{self.speculative_config.method}"
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
self.rejection_sampler = RejectionSampler(self.sampler)
else:
self.rejection_sampler = OptRejectionSampler(self.sampler)
self.num_spec_tokens = 0
if self.speculative_config:
......@@ -702,6 +708,8 @@ class GPUModelRunner(
self.mamba_state_idx: dict[str, int] = {}
self.layerwise_nvtx_hooks_registered = False
self.draft_probs : Optional[DraftProbs] = None
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
if self.speculative_config:
......@@ -874,6 +882,10 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if envs.VLLM_REJECT_SAMPLE_OPT and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
......@@ -1616,8 +1628,13 @@ class GPUModelRunner(
>= self.input_batch.num_prompt_tokens[req_idx]
):
num_decode_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_ids = None
if envs.VLLM_REJECT_SAMPLE_OPT:
spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens
num_draft_tokens, cu_num_tokens, spec_decode_ids
)
logits_indices = spec_decode_metadata.logits_indices
num_sampled_tokens = num_draft_tokens + 1
......@@ -2118,6 +2135,7 @@ class GPUModelRunner(
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
spec_decode_ids: Optional[list[str]] = None
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
......@@ -2191,6 +2209,7 @@ class GPUModelRunner(
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
spec_decode_ids=spec_decode_ids,
)
def _prepare_kv_sharing_fast_prefill(
......@@ -2838,7 +2857,8 @@ class GPUModelRunner(
sampler_output = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
None if self.draft_probs is None else \
self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids), # draft_probs
logits,
sampling_metadata,
)
......@@ -3999,7 +4019,7 @@ class GPUModelRunner(
else:
mm_embed_inputs = None
draft_token_ids = self.drafter.propose(
draft_result = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
......@@ -4012,6 +4032,19 @@ class GPUModelRunner(
slot_mappings=slot_mappings,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
else:
draft_token_ids, draft_probs = draft_result
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
......@@ -4651,6 +4684,9 @@ class GPUModelRunner(
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs()
else:
self.input_ids.gpu[:num_tokens_padded] = torch.randint(0, self.model_config.get_vocab_size(),
(num_tokens_padded,),
dtype=torch.int32)
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
......@@ -4836,7 +4872,14 @@ class GPUModelRunner(
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = None
else:
draft_probs = torch.randn(
num_reqs, self.speculative_config.num_speculative_tokens, logits.shape[-1], device=self.device,
dtype=logits.dtype)
logits = torch.randn(
num_tokens + num_reqs,
logits.shape[-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment