Commit 9dd945c1 authored by 王敏's avatar 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent 7e71c143
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from functools import cached_property
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
PLACEHOLDER_TOKEN_ID = -1
class MtpRejectionSampler(RejectionSampler):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def __init__(self):
super().__init__()
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
assert draft_probs is not None
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
num_draft_tokens = metadata.num_draft_tokens[0]
target_probs = compute_probs(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
num_draft_tokens
)
target_probs = target_probs.view(-1, num_draft_tokens, target_probs.shape[-1])
draft_probs = draft_probs.view(-1, num_draft_tokens, draft_probs.shape[-1])
draft_token_ids = metadata.draft_token_ids.view(-1, num_draft_tokens)
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
draft_token_ids,
None,
))
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _create_uniform_samples(self,
seeded_seqs: Optional[dict[int,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if not seeded_seqs:
return torch.rand(batch_size, k + 1, device=device)
uniform_rand = torch.empty(batch_size, k + 1, device=device)
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(1,
k + 1,
dtype=self.probs_dtype,
device=device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k + 1,
dtype=self.probs_dtype,
device=device)
return uniform_rand
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
$\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
same conditional probability according to the draft model, the token
is accepted with probability:
$$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
$$
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size, k, _ = draft_probs.shape
batch_indices = torch.arange(batch_size,
device=target_probs.device)[:, None]
probs_indices = torch.arange(k, device=target_probs.device)
# shape [batch_size, k]
selected_draft_probs = draft_probs[batch_indices, probs_indices,
draft_token_ids]
# shape [batch_size, k]
selected_target_probs = target_probs[batch_indices, probs_indices,
draft_token_ids]
uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
k - 1, target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
accepted = uniform_rand < capped_ratio
return accepted
def _get_recovered_probs(
self,
target_probs: torch.Tensor, # [k, vocab_size]
draft_probs: torch.Tensor, # [k, vocab_size]
) -> torch.Tensor:
r"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
$x$ given context $x_1, \dots, x_n$ according to the target
model and $p(x|x_1, \dots, x_n)$, the same conditional probability
according to the draft model:
$$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
$$
where $(f(x))_+$ is defined as:
$$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
$$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note:
This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_, k, _ = draft_probs.shape
# shape [batch_size, k, vocab_size]
difference = target_probs - draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f = torch.clamp(difference, min=self._smallest_positive_value)
# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
return recovered_probs
def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size, k, vocab_size = target_probs.shape
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, seeded_seqs)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)
return accepted, recovered_token_ids
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
substitute_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via sampling, all subsequent token ids are
set to -1 for the sequence.
Args:
accepted: A boolean tensor indicating if the corresponding
draft token in draft_token_ids should be accepted or not.
substitute_token_ids: A tensor of token_ids that can be used
as substitutes for the draft token ids if the proposed token
is rejected.
draft_token_ids: A tensor of token ids speculated by the
draft model.
bonus_token_ids: Token ids to use as the bonus token if
all the draft tokens are accepted.
Returns:
A tensor containing the accepted token ids. The shape of the
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))
return output_with_bonus_tokens
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
]
return outputs
@cached_property
def _smallest_positive_value(self) -> float:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return torch.finfo(self.probs_dtype).tiny
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs)
if not seeded_seqs:
q.exponential_(1.0)
else:
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
# Note: generator might be None for non seeded
q[start:end].exponential_(1.0, generator=generator)
start = end
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def compute_probs(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata,
spec_len: int
) -> torch.Tensor:
"""Compute probability distribution from logits based on sampling metadata.
This function applies temperature scaling to the logits and converts
them to probabilities using softmax. For greedy decoding, it returns
the original logits.
Args:
logits: Input logits tensor to be converted to probabilities.
cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used.
Returns:
torch.Tensor: Probability distribution (softmax of scaled logits)
if non-greedy sampling is used, otherwise returns the
original logits.
"""
assert logits.ndim == 2
assert cu_num_draft_tokens.ndim == 1
if sampling_metadata.all_greedy:
return logits
# num_tokens = logits.shape[0]
temperature = sampling_metadata.temperature.view(-1, 1).repeat(1, spec_len).view(-1)
temperature = torch.where(temperature > 0, temperature, 1)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors.
top_k = None
if sampling_metadata.top_k is not None:
top_k = sampling_metadata.top_k.view(-1, 1).repeat(1, spec_len).view(-1)
top_p = None
if sampling_metadata.top_p is not None:
top_p = sampling_metadata.top_p.view(-1, 1).repeat(1, spec_len).view(-1)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p)
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
return output_prob
...@@ -107,7 +107,7 @@ class EagleProposer: ...@@ -107,7 +107,7 @@ class EagleProposer:
num_rejected_tokens: list[int], num_rejected_tokens: list[int],
# [batch_size] # [batch_size]
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1 last_token_indices = cu_num_tokens[1:] - 1
...@@ -231,16 +231,13 @@ class EagleProposer: ...@@ -231,16 +231,13 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob] draft_token_ids = torch.argmax(logits, dim=-1)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, draft_prob.shape[-1]) return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has # TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once # one layer. Adapt this code to support multiple layers once
...@@ -257,7 +254,7 @@ class EagleProposer: ...@@ -257,7 +254,7 @@ class EagleProposer:
hidden_states = hidden_states[last_token_indices] hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \ if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]: batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
...@@ -383,18 +380,14 @@ class EagleProposer: ...@@ -383,18 +380,14 @@ class EagleProposer:
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size],
None) None)
# TODO(wenlong): get more than one token for tree attention # # TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs return draft_token_ids
@staticmethod @staticmethod
def prepare_inputs( def prepare_inputs(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC
import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -43,41 +39,3 @@ def prepare_eagle_input_kernel( ...@@ -43,41 +39,3 @@ def prepare_eagle_input_kernel(
index_start + offset, index_start + offset,
mask=offset < num_tokens, mask=offset < num_tokens,
) )
class DraftProbs(ABC): # type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs: torch.Tensor
# The request id list.
_req_ids: list[str]
def __init__(self, draft_probs, req_ids):
assert len(req_ids) == len(draft_probs)
self.draft_probs = draft_probs
self._req_ids = req_ids
def update(self,
draft_probs: torch.Tensor,
tmp_req_ids: list[str]):
diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids]
index = [self._req_ids.index(req_id) for req_id in diff_req_ids]
self._req_ids = diff_req_ids
self.draft_probs = self.draft_probs[index]
self.draft_probs = torch.cat([self.draft_probs, draft_probs])
self._req_ids.extend(tmp_req_ids)
assert len(self._req_ids) == len(self.draft_probs)
def prune(self, req_ids: list[str]):
new_req_ids = [req_id for req_id in self._req_ids if req_id not in req_ids]
if new_req_ids != self._req_ids:
# Batch contents changed - prune removed sequences.
index = [self._req_ids.index(req_id) for req_id in new_req_ids]
self.draft_probs = self.draft_probs[index]
self._req_ids = new_req_ids
def get_probs(self, req_ids: list[str]):
index = [self._req_ids.index(req_id) for req_id in req_ids]
return self.draft_probs[index]
...@@ -58,13 +58,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ...@@ -58,13 +58,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_mtp import MtpRejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -194,11 +192,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -194,11 +192,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise ValueError("Unknown speculative decoding method: " raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}") f"{self.speculative_config.method}")
self.use_mtp = self.speculative_config.method == "deepseek_mtp" self.rejection_sampler = RejectionSampler()
if not self.use_mtp:
self.rejection_sampler = RejectionSampler()
else:
self.rejection_sampler = MtpRejectionSampler()
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
...@@ -325,8 +319,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -325,8 +319,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
self.draft_probs : Optional[DraftProbs] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -386,10 +378,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -386,10 +378,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if self.use_mtp and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs. # Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids: for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id) encoder_outputs = self.encoder_cache.get(req_id)
...@@ -547,7 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -547,7 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = ( spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids: if spec_token_ids:
num_spec_tokens = len(spec_token_ids) num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index] start_index = self.input_batch.num_tokens_no_spec[req_index]
...@@ -1465,8 +1452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1465,8 +1452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits = logits[spec_decode_metadata.target_logits_indices] target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler( output_token_ids = self.rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
self.draft_probs.get_probs(self.input_batch.req_ids) \ None, # draft_probs
if self.draft_probs is not None else None, # draft_probs
target_logits, target_logits,
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1551,7 +1537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1551,7 +1537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: else:
spec_token_ids, draft_probs = self.propose_draft_token_ids( spec_token_ids = self.propose_draft_token_ids(
scheduler_output, scheduler_output,
valid_sampled_token_ids, valid_sampled_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1562,13 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1562,13 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
if self.use_mtp:
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
spec_token_ids = spec_token_ids.tolist() spec_token_ids = spec_token_ids.tolist()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
...@@ -1587,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1587,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output=[], pooler_output=[],
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits num_nans_in_logits=num_nans_in_logits,
) )
def propose_draft_token_ids( def propose_draft_token_ids(
...@@ -1600,8 +1579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1600,8 +1579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor], aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata], spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any],
) -> tuple[list[list[int]], torch.Tensor]: ) -> list[list[int]]:
draft_probs = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
...@@ -1700,7 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1700,7 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
spec_token_ids, draft_probs = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
...@@ -1711,8 +1689,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1711,8 +1689,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
num_rejected_tokens=num_rejected_tokens num_rejected_tokens=num_rejected_tokens
) )
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs return spec_token_ids
def kv_connector_no_forward( def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
...@@ -2168,10 +2147,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2168,10 +2147,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids, self.device) draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids) num_tokens = sum(len(ids) for ids in draft_token_ids)
draft_probs = torch.randn( # draft_probs = torch.randn(
num_tokens, logits.shape[-1], device=self.device, # num_tokens, logits.shape[-1], device=self.device,
dtype=logits.dtype) # dtype=logits.dtype)
# draft_probs = None draft_probs = None
target_logits = torch.randn(num_tokens, target_logits = torch.randn(num_tokens,
logits.shape[-1], logits.shape[-1],
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment