Unverified commit 053f3b63, authored by Giancarlo Delfin and committed by GitHub
Browse files

[Model Runner V2] Spec decode rejection sampler logprobs support (#37237)


Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
parent 5f82706a
......@@ -3,11 +3,14 @@
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
@triton.jit
......@@ -418,6 +421,26 @@ def probabilistic_rejection_sample(
return sampled, rejected_steps + 1
@triton.jit
def _flatten_sampled_kernel(
    # Output buffer, [num_logits]
    flat_sampled_ptr,
    # Per-request sampled token ids, [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    # Element stride between consecutive rows of `sampled_ptr`
    sampled_stride,
    # Number of valid entries in each row of `sampled_ptr`, [num_reqs]
    num_sampled_ptr,
    # Start offset of each request inside the flat buffer, [num_reqs + 1]
    cu_num_logits_ptr,
):
    # Each program instance handles one request: it copies the first
    # `num_sampled[req]` token ids of that request's row into the flat
    # output, starting at offset `cu_num_logits[req]`. Positions of the
    # flat buffer that fall outside that range are not written here.
    req_idx = tl.program_id(0)

    # Base addresses of the source row and of the destination segment.
    row_ptr = sampled_ptr + req_idx * sampled_stride
    out_ptr = flat_sampled_ptr + tl.load(cu_num_logits_ptr + req_idx)

    valid_count = tl.load(num_sampled_ptr + req_idx)
    for offset in range(valid_count):
        tl.store(out_ptr + offset, tl.load(row_ptr + offset))
class RejectionSampler:
def __init__(
self,
......@@ -429,6 +452,40 @@ class RejectionSampler:
self.num_speculative_steps = num_speculative_steps
self.use_strict_rejection_sampling = use_strict_rejection_sampling
def _get_logprobs_tensors(
    self,
    input_batch: InputBatch,
    sampled: torch.Tensor,
    num_sampled: torch.Tensor,
    logits: torch.Tensor,
) -> LogprobsTensors | None:
    """Compute top-k logprobs for the tokens kept by rejection sampling.

    Args:
        input_batch: Batch metadata; this method reads ``idx_mapping``,
            ``idx_mapping_np``, ``cu_num_logits`` and ``cu_num_logits_np``.
        sampled: Per-request sampled token ids, one row per request.
        num_sampled: Number of valid entries in each row of ``sampled``.
        logits: Logits to derive logprobs from, one row per logit position.

    Returns:
        ``None`` when no request in the batch asks for logprobs, otherwise
        the result of ``compute_topk_logprobs``.
    """
    sampling_states = self.sampler.sampling_states
    max_num_logprobs = sampling_states.max_num_logprobs(input_batch.idx_mapping_np)
    if max_num_logprobs == NO_LOGPROBS:
        # Nothing requested; skip the flattening kernel entirely.
        return None

    cu_num_logits = input_batch.cu_num_logits
    request_count = cu_num_logits.shape[0] - 1
    total_logits = logits.shape[0]

    # One token id per logits row. Slots the kernel does not write (beyond
    # each request's `num_sampled`) keep the zero they are initialised with.
    flat_sampled = torch.zeros(
        total_logits,
        dtype=sampled.dtype,
        device=sampled.device,
    )
    _flatten_sampled_kernel[(request_count,)](
        flat_sampled,
        sampled,
        sampled.stride(0),
        num_sampled,
        cu_num_logits,
        num_warps=1,
    )

    # Cumulative offsets are only forwarded when the number of logits rows
    # differs from the number of entries in `idx_mapping`.
    if total_logits != input_batch.idx_mapping.shape[0]:
        cu_num_logits_list = input_batch.cu_num_logits_np.tolist()
    else:
        cu_num_logits_list = None

    return compute_topk_logprobs(
        logits,
        max_num_logprobs,
        flat_sampled,
        cu_num_logits_list,
    )
def __call__(
self,
logits: torch.Tensor,
......@@ -460,8 +517,6 @@ class RejectionSampler:
draft_sampled,
input_batch.expanded_local_pos,
)
# TODO (TheEpicDolphin): Return logprobs for sampled token ids.
logprobs_tensors = None
sampled, num_sampled = probabilistic_rejection_sample(
processed_logits,
draft_logits,
......@@ -475,6 +530,14 @@ class RejectionSampler:
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
)
logprobs_tensors = self._get_logprobs_tensors(
input_batch,
sampled,
num_sampled,
processed_logits
if self.sampler.logprobs_mode == "processed_logprobs"
else logits,
)
return SamplerOutput(
sampled_token_ids=sampled,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment