Commit ade7db0c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm-1126' into 'v0.9.2-dev'

[feat]支持宽松mtp

See merge request dcutoolkit/deeplearing/vllm!269
parents 9aadeed6 b9bc84e2
......@@ -185,6 +185,7 @@ if TYPE_CHECKING:
VLLM_USE_ZERO_MTP: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_CAT_MLA: bool = False
VLLM_REJECT_SAMPLE_OPT: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1199,7 +1200,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use fused cat and mla
"VLLM_USE_CAT_MLA":
lambda: (os.getenv('VLLM_USE_CAT_MLA', 'False').lower() in
("true", "1")),
("true", "1")),
# vllm will use the relaxed rejection sampler (OptRejectionSampler) for spec decode
"VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
# Module-level logger for this file.
logger = init_logger(__name__)
# Sentinel stored in output slots that hold no token: it is the fill value of
# the output buffer and is filtered out again by `parse_output`.
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
# NOTE(review): not referenced anywhere in the visible code; presumably the
# temperature value that marks greedy requests -- confirm before relying on it.
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
class OptRejectionSampler(nn.Module):
    """Relaxed ("opt") rejection sampler for speculative decoding.

    The scheme is modelled on the speculative-sampling algorithm described in
    https://arxiv.org/abs/2211.17192. This module only prepares the inputs;
    the accept/reject decision is made by `rejection_sample` and the Triton
    kernel it launches.

    Terminology used by the implementation:
    accepted tokens: draft tokens that pass the acceptance test in the
        kernel.
    replacement token: the entry of `target_tokens` that is emitted at the
        first position whose draft token is rejected; nothing is emitted
        after it for that request.
    bonus tokens:
        If all proposed tokens of a request are accepted, the bonus token is
        appended to the end of its sequence. Bonus tokens are passed in
        rather than sampled here, so the caller is free to choose how they
        are sampled.
    output tokens:
        The tokens finally produced for a request:
        accepted tokens, then either one replacement token or the bonus
        token.
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # 3-D (asserted in `rejection_sample`) or None
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # one token id per draft position
        target_tokens: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding. Its `draft_token_ids`,
                `num_draft_tokens`, `max_spec_len` and `cu_num_draft_tokens`
                fields are read here.
            draft_probs (Optional[torch.Tensor]):
                Probabilities produced by the draft model, or None when the
                proposer does not provide any. `rejection_sample` asserts
                that a non-None value is 3-D.
            target_logits (torch.Tensor):
                Target model logits for the draft positions, flattened over
                requests. Shape is [num_tokens, vocab_size]. They are
                converted to float32 probabilities with a softmax.
            target_tokens (torch.Tensor):
                Token ids sampled from the target model at the draft
                positions; used both for the equality shortcut and as the
                replacement token on rejection.
            bonus_token_ids (torch.Tensor):
                Bonus tokens, one per request. Shape is [batch_size, 1].
                They are generated outside of this sampler and appended when
                every draft token of a request is accepted.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Sampling metadata; forwarded unchanged to
                `rejection_sample`.
        Returns:
            output_token_ids (torch.Tensor):
                int32 tensor of shape [batch_size, max_spec_len + 1] whose
                unused slots hold `PLACEHOLDER_TOKEN_ID`.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN

        # Probabilities are computed in float32 regardless of logits dtype.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        # Draft ids equal to -1 are mapped to token 0 so they can be used as
        # vocabulary indices (original note: compatibility with the first
        # decode step).
        raw_draft_ids = metadata.draft_token_ids
        is_placeholder = raw_draft_ids.eq(-1).to(torch.bool)
        safe_draft_ids = torch.where(is_placeholder, 0, raw_draft_ids)
        safe_draft_ids = safe_draft_ids.to(torch.long)

        return rejection_sample(
            safe_draft_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            target_tokens,
            bonus_token_ids,
            sampling_metadata,
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Convert the sampler output into per-request token lists.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. Slots without a token hold
                `PLACEHOLDER_TOKEN_ID` and are dropped here.
            vocab_size: The size of the vocabulary; ids at or above it are
                dropped as well.
        Returns:
            A list of lists of token IDs, one inner list per row.
        """
        token_rows = output_token_ids.cpu().numpy()
        # A slot is kept only if it is not a placeholder and is a real
        # vocabulary id.
        not_placeholder = token_rows != PLACEHOLDER_TOKEN_ID
        in_vocab = token_rows < vocab_size
        keep = not_placeholder & in_vocab
        return [
            row[row_keep].tolist() for row, row_keep in zip(token_rows, keep)
        ]
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # 3-D (asserted below) or None
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # read by the kernel with one entry per draft token
    target_tokens,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run the relaxed rejection-sampling kernel over a flattened batch.

    Allocates the output buffer and the per-token acceptance thresholds, then
    launches `rejection_random_sample_kernel` with one program per request.

    The thresholds are drawn from `torch.rand` and rescaled to the range
    [0.1, 0.2) instead of being used as-is, which is what makes acceptance
    "relaxed".

    `sampling_metadata` is accepted but not used by this function.

    NOTE(review): the kernel reads `draft_probs` as a flat
    [num_tokens, vocab_size] buffer, so the 3-D tensor presumably has to hold
    exactly the scheduled draft tokens in request order -- confirm against
    the caller.

    Returns:
        int32 tensor of shape [batch_size, max_spec_len + 1]; slots that
        receive no token keep `PLACEHOLDER_TOKEN_ID`.
    """
    # Rank checks.
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 3
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    batch_size = len(num_draft_tokens)
    device = target_probs.device

    # The kernel does raw pointer arithmetic, so inputs must be contiguous.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Output buffer, pre-filled with the placeholder id.
    # int32 is consistent with SamplerOutput.sampled_token_ids.
    out_shape = (batch_size, max_spec_len + 1)
    output_token_ids = torch.full(
        out_shape,
        fill_value=PLACEHOLDER_TOKEN_ID,
        dtype=torch.int32,
        device=device,
    )

    # One acceptance threshold per draft token, rescaled into [0.1, 0.2).
    raw_uniform = torch.rand(
        (num_tokens, ),
        dtype=torch.float32,
        device=device,
    )
    uniform_probs = raw_uniform * 0.1 + 0.1

    # Rejection sampling for random sampling requests.
    grid = (batch_size, )
    rejection_random_sample_kernel[grid](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        uniform_probs,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # read as flat [num_tokens, vocab_size], or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # Each program handles one request. Draft positions are walked in order;
    # a draft token is kept when it equals the target token or when
    # target_prob / draft_prob reaches the per-token threshold. At the first
    # rejection the target token is written instead and the walk stops
    # emitting. If nothing was rejected, the bonus token is appended.
    req_idx = tl.program_id(0)

    # The request's token span is [start_idx, end_idx) in the flat arrays,
    # taken from the cumulative draft-token counts.
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # First output slot of this request's row.
    out_row_ptr = output_token_ids_ptr + req_idx * (max_spec_len + 1)

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            tok_idx = start_idx + pos
            prob_row = tok_idx * vocab_size

            draft_token_id = tl.load(draft_token_ids_ptr + tok_idx)
            if NO_DRAFT_PROBS:
                # Without draft probabilities the ratio is the target
                # probability itself.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr + prob_row +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr + prob_row +
                                  draft_token_id)
            target_token_id = tl.load(target_token_ids_ptr + tok_idx)
            target_token_id = target_token_id.to(tl.int64)
            uniform_prob = tl.load(uniform_probs_ptr + tok_idx)

            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if (draft_token_id == target_token_id) or (
                    target_prob / draft_prob >= uniform_prob
                    and draft_prob > 0):
                token_id = draft_token_id
            else:
                rejected = True
                token_id = target_token_id
            tl.store(out_row_ptr + pos, token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(out_row_ptr + num_draft_tokens, bonus_token_id)
......@@ -5,7 +5,9 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
......@@ -235,6 +237,10 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob]
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
......@@ -385,9 +391,17 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs
return draft_token_ids
# @staticmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
......@@ -21,6 +22,8 @@ class SpecDecodeMetadata:
bonus_logits_indices: torch.Tensor
# [num_tokens + batch_size]
logits_indices: torch.Tensor
# [batch_size]
spec_decode_ids: Optional[list[str]] = None
def __post_init__(self):
self.max_spec_len = max(self.num_draft_tokens)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
......@@ -39,3 +43,42 @@ def prepare_eagle_input_kernel(
index_start + offset,
mask=offset < num_tokens,
)
class DraftProbs(ABC):
    """Draft probs corresponding to in-progress sequences.

    Keeps one block of draft-token probabilities per request. Row ``i`` of
    ``draft_probs`` always belongs to ``_req_ids[i]``; every method below
    preserves that pairing.
    """

    # Spec-token probabilities; indexed along dim 0 by request.
    draft_probs: torch.Tensor
    # Request ids, parallel to the rows of `draft_probs`.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Store `draft_probs` together with the ids that own its rows.

        Args:
            draft_probs: Tensor with one leading row per request.
            req_ids: Request ids, in the same order as the rows.
        """
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Take a private copy: the caller hands in its own live list, and any
        # later in-place change to that list would otherwise silently break
        # the id <-> row pairing.
        self._req_ids = list(req_ids)

    def update(self,
               draft_probs: torch.Tensor,
               tmp_req_ids: list[str]):
        """Replace or add the rows for `tmp_req_ids`.

        Rows of requests that are not in `tmp_req_ids` are kept (in their
        current order); the new rows are appended after them.

        Args:
            draft_probs: New probabilities, one row per id in `tmp_req_ids`.
            tmp_req_ids: Ids owning the rows of `draft_probs`.
        """
        incoming = set(tmp_req_ids)
        kept_rows = [
            row for row, req_id in enumerate(self._req_ids)
            if req_id not in incoming
        ]
        kept_ids = [self._req_ids[row] for row in kept_rows]
        self.draft_probs = torch.cat(
            [self.draft_probs[kept_rows], draft_probs])
        self._req_ids = kept_ids + list(tmp_req_ids)
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the rows of the given requests; unknown ids are ignored.

        Args:
            req_ids: Ids of the requests whose rows should be removed.
        """
        removed = set(req_ids)
        kept_rows = [
            row for row, req_id in enumerate(self._req_ids)
            if req_id not in removed
        ]
        if len(kept_rows) != len(self._req_ids):
            # Batch contents changed - prune removed sequences.
            self.draft_probs = self.draft_probs[kept_rows]
            self._req_ids = [self._req_ids[row] for row in kept_rows]

    def get_probs(self, req_ids: list[str]):
        """Return the rows for `req_ids`, in the order requested.

        Args:
            req_ids: Iterable of request ids to look up.

        Returns:
            A tensor whose i-th row is the stored row of the i-th id.

        Raises:
            ValueError: If an id is not tracked.
        """
        # First occurrence wins, matching `list.index` semantics.
        row_of: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            row_of.setdefault(req_id, row)
        try:
            rows = [row_of[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in list") from None
        return self.draft_probs[rows]
......@@ -59,6 +59,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
......@@ -75,6 +76,7 @@ from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.v1.spec_decode.utils import DraftProbs
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -197,7 +199,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
if not envs.VLLM_REJECT_SAMPLE_OPT:
self.rejection_sampler = RejectionSampler()
else:
self.rejection_sampler = OptRejectionSampler()
# Request states.
self.requests: dict[str, CachedRequestState] = {}
......@@ -324,6 +329,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
self.draft_probs : Optional[DraftProbs] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -383,6 +390,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if envs.VLLM_REJECT_SAMPLE_OPT and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
......@@ -762,13 +773,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_ids = None
if envs.VLLM_REJECT_SAMPLE_OPT:
spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
num_draft_tokens, cu_num_tokens, spec_decode_ids)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
......@@ -922,6 +938,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
spec_decode_ids: Optional[list[str]] = None
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
......@@ -993,6 +1010,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
spec_decode_ids=spec_decode_ids,
)
return metadata
......@@ -1491,25 +1509,47 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
if not envs.VLLM_REJECT_SAMPLE_OPT:
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
else:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
target_logits = logits[spec_decode_metadata.target_logits_indices]
bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
target_logits,
target_token_ids,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
......@@ -1590,7 +1630,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids = None
else:
spec_token_ids = self.propose_draft_token_ids(
spec_result = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
......@@ -1600,6 +1640,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
attn_metadata,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, draft_probs = spec_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
......@@ -1722,7 +1771,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
draft_token_ids = self.drafter.propose(
draft_result = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
......@@ -1733,7 +1782,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
decoding=spec_decode_metadata is not None
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
draft_token_ids, draft_probs = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids
def kv_connector_no_forward(
......@@ -2190,15 +2248,20 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else:
raise e
if self.speculative_config:
draft_token_ids = [[0] for _ in range(num_reqs)]
draft_token_ids = [[0]*self.speculative_config.num_lookahead_slots for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = None
else:
draft_probs = torch.randn(
num_reqs, self.speculative_config.num_lookahead_slots, logits.shape[-1], device=self.device,
dtype=logits.dtype)
target_token_ids = torch.zeros(num_tokens, device=self.device,
dtype=torch.int32)
target_logits = torch.randn(num_tokens,
logits.shape[-1],
device=self.device,
......@@ -2209,13 +2272,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
bonus_token_ids = torch.zeros(num_reqs,
device=self.device,
dtype=torch.int32)
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
dummy_metadata,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
dummy_metadata,
)
else:
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
target_token_ids,
bonus_token_ids,
dummy_metadata,
)
return sampler_output
@torch.inference_mode()
......@@ -3050,8 +3123,12 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_ids = None
if envs.VLLM_REJECT_SAMPLE_OPT:
spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
num_draft_tokens, cu_num_tokens, spec_decode_ids)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
......@@ -3258,25 +3335,46 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
if not envs.VLLM_REJECT_SAMPLE_OPT:
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
else:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
target_logits = logits[spec_decode_metadata.target_logits_indices]
bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
target_logits,
target_token_ids,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
......@@ -3325,7 +3423,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.zero_propose_draft_token_ids(
spec_result = self.zero_propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
......@@ -3336,6 +3435,15 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata,
attn_metadata,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, draft_probs = spec_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
......@@ -3479,7 +3587,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
draft_token_ids = self.drafter.propose(
draft_result = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
......@@ -3494,7 +3602,14 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# self.last_draft_token_ids = draft_token_ids
# self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
# self.last_draft_event.record()
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
draft_token_ids, draft_probs = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids
# TODO: once stable, replace GPUModelRunner with GPUModelRunnerMTP
if envs.VLLM_USE_ZERO_MTP:
......
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -161,6 +164,10 @@ class V1ZeroEagleProposer(EagleProposer):
draft_token_ids = logits.argmax(dim=-1)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob]
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
......@@ -311,7 +318,15 @@ class V1ZeroEagleProposer(EagleProposer):
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs
return draft_token_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment