Commit 9dd945c1 authored by 王敏's avatar 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent 7e71c143
This diff is collapsed.
......@@ -107,7 +107,7 @@ class EagleProposer:
num_rejected_tokens: list[int],
# [batch_size]
sampling_metadata: SamplingMetadata
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
......@@ -231,16 +231,13 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob]
draft_token_ids = torch.argmax(logits, dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, draft_prob.shape[-1])
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
......@@ -257,7 +254,7 @@ class EagleProposer:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
......@@ -383,18 +380,14 @@ class EagleProposer:
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
# # TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs
return draft_token_ids
@staticmethod
def prepare_inputs(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
......@@ -43,41 +39,3 @@ def prepare_eagle_input_kernel(
index_start + offset,
mask=offset < num_tokens,
)
class DraftProbs(ABC): # type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs: torch.Tensor
# The request id list.
_req_ids: list[str]
def __init__(self, draft_probs, req_ids):
assert len(req_ids) == len(draft_probs)
self.draft_probs = draft_probs
self._req_ids = req_ids
def update(self,
draft_probs: torch.Tensor,
tmp_req_ids: list[str]):
diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids]
index = [self._req_ids.index(req_id) for req_id in diff_req_ids]
self._req_ids = diff_req_ids
self.draft_probs = self.draft_probs[index]
self.draft_probs = torch.cat([self.draft_probs, draft_probs])
self._req_ids.extend(tmp_req_ids)
assert len(self._req_ids) == len(self.draft_probs)
def prune(self, req_ids: list[str]):
new_req_ids = [req_id for req_id in self._req_ids if req_id not in req_ids]
if new_req_ids != self._req_ids:
# Batch contents changed - prune removed sequences.
index = [self._req_ids.index(req_id) for req_id in new_req_ids]
self.draft_probs = self.draft_probs[index]
self._req_ids = new_req_ids
def get_probs(self, req_ids: list[str]):
index = [self._req_ids.index(req_id) for req_id in req_ids]
return self.draft_probs[index]
......@@ -58,13 +58,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_mtp import MtpRejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......@@ -194,11 +192,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.use_mtp = self.speculative_config.method == "deepseek_mtp"
if not self.use_mtp:
self.rejection_sampler = RejectionSampler()
else:
self.rejection_sampler = MtpRejectionSampler()
self.rejection_sampler = RejectionSampler()
# Request states.
self.requests: dict[str, CachedRequestState] = {}
......@@ -325,8 +319,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
self.draft_probs : Optional[DraftProbs] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -386,10 +378,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if self.use_mtp and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
......@@ -547,7 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index]
......@@ -1465,8 +1452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
self.draft_probs.get_probs(self.input_batch.req_ids) \
if self.draft_probs is not None else None, # draft_probs
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
......@@ -1551,7 +1537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids = None
else:
spec_token_ids, draft_probs = self.propose_draft_token_ids(
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
......@@ -1562,13 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata,
)
if self.use_mtp:
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
spec_token_ids = spec_token_ids.tolist()
# Clear KVConnector state after all KVs are generated.
......@@ -1587,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits
num_nans_in_logits=num_nans_in_logits,
)
def propose_draft_token_ids(
......@@ -1600,8 +1579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any],
) -> tuple[list[list[int]], torch.Tensor]:
draft_probs = None
) -> list[list[int]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
......@@ -1700,7 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
spec_token_ids, draft_probs = self.drafter.propose(
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
......@@ -1711,8 +1689,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
num_rejected_tokens=num_rejected_tokens
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
......@@ -2168,10 +2147,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids)
draft_probs = torch.randn(
num_tokens, logits.shape[-1], device=self.device,
dtype=logits.dtype)
# draft_probs = None
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
target_logits = torch.randn(num_tokens,
logits.shape[-1],
device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment