Commit 93fae6b1 authored by zhuwenwen
Browse files

[feat]1.支持mtp模型 full_cuda_graph; 2.优化mtp拒绝采样

parent 3d57a0d3
...@@ -152,7 +152,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -152,7 +152,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return logits return logits
#@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -57,6 +59,9 @@ class EagleProposer: ...@@ -57,6 +59,9 @@ class EagleProposer:
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
self.use_full_cuda_graph = (
self.use_cuda_graph
and vllm_config.compilation_config.full_cuda_graph)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed( reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes)) self.vllm_config.compilation_config.cudagraph_capture_sizes))
...@@ -72,6 +77,8 @@ class EagleProposer: ...@@ -72,6 +77,8 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
device=device) device=device)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
...@@ -131,6 +138,38 @@ class EagleProposer: ...@@ -131,6 +138,38 @@ class EagleProposer:
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -147,11 +186,15 @@ class EagleProposer: ...@@ -147,11 +186,15 @@ class EagleProposer:
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob]
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1), draft_probs_list
# TODO: Currently, MTP module released by deepseek only has # TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once # one layer. Adapt this code to support multiple layers once
...@@ -191,7 +234,7 @@ class EagleProposer: ...@@ -191,7 +234,7 @@ class EagleProposer:
seq_lens=(seq_lens + 1), seq_lens=(seq_lens + 1),
) )
for _ in range(self.num_speculative_tokens - 1): for i in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
...@@ -242,6 +285,43 @@ class EagleProposer: ...@@ -242,6 +285,43 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
...@@ -265,10 +345,15 @@ class EagleProposer: ...@@ -265,10 +345,15 @@ class EagleProposer:
# TODO(wenlong): get more than one token for tree attention # TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids, draft_probs
def prepare_inputs( def prepare_inputs(
self, self,
...@@ -418,8 +503,13 @@ class EagleProposer: ...@@ -418,8 +503,13 @@ class EagleProposer:
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC
import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -12,3 +16,41 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: ...@@ -12,3 +16,41 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
or sampling_params.repetition_penalty != 1.0 or sampling_params.repetition_penalty != 1.0
or sampling_params.min_p > _SAMPLING_EPS or sampling_params.min_p > _SAMPLING_EPS
or sampling_params.logprobs is not None) or sampling_params.logprobs is not None)
class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probs corresponding to in-progress sequences.

    Keeps a tensor of draft-token probabilities together with the request id
    that owns each entry along dim 0, so the entries for an arbitrary set of
    requests can be looked up, replaced, or dropped.
    """

    # Spec-token probs; entry i along dim 0 belongs to _req_ids[i].
    draft_probs: torch.Tensor
    # The request id list, one id per entry of draft_probs along dim 0.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Store `draft_probs` and the ids owning each of its rows.

        Args:
            draft_probs: Tensor whose length along dim 0 equals
                `len(req_ids)`.
            req_ids: Request ids, in the same order as the rows of
                `draft_probs`.
        """
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Copy the ids: the caller may keep mutating the list it passed in,
        # and sharing it would let our ids drift away from the stored rows.
        self._req_ids = list(req_ids)

    def _rows(self, req_ids):
        """Return the row position of each id in `req_ids`.

        Raises:
            ValueError: If an id is not currently tracked.
        """
        position: dict = {}
        for row, req_id in enumerate(self._req_ids):
            # Keep the first occurrence, matching `list.index` semantics.
            position.setdefault(req_id, row)
        try:
            return [position[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in list") from None

    def _select(self, rows):
        """Return the stored rows at positions `rows`, in that order."""
        if not rows:
            # Empty selection: keep dtype/device and the trailing shape.
            return self.draft_probs[:0]
        return self.draft_probs[rows]

    def update(self,
               draft_probs: torch.Tensor,
               tmp_req_ids: list[str]):
        """Replace the entries of `tmp_req_ids` with fresh probabilities.

        Rows of requests that are not in `tmp_req_ids` are kept in their
        current relative order; the rows of `draft_probs` are appended after
        them, in the order given by `tmp_req_ids`.

        Args:
            draft_probs: New probabilities, one row per id in `tmp_req_ids`.
            tmp_req_ids: Ids owning the rows of `draft_probs`.
        """
        incoming = set(tmp_req_ids)
        kept_ids = [rid for rid in self._req_ids if rid not in incoming]
        kept_probs = self._select(self._rows(kept_ids))

        self.draft_probs = torch.cat([kept_probs, draft_probs])
        self._req_ids = kept_ids + list(tmp_req_ids)
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the entries that belong to `req_ids`; unknown ids are ignored.

        Args:
            req_ids: Ids whose rows should be removed.
        """
        removed = set(req_ids)
        survivors = [rid for rid in self._req_ids if rid not in removed]
        if survivors != self._req_ids:
            # Batch contents changed - prune removed sequences.
            self.draft_probs = self._select(self._rows(survivors))
            self._req_ids = survivors

    def get_probs(self, req_ids: list[str]):
        """Return the stored rows for `req_ids`, in the order requested.

        Raises:
            ValueError: If any id in `req_ids` has no stored entry.
        """
        return self._select(self._rows(req_ids))
...@@ -60,11 +60,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ...@@ -60,11 +60,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_mtp import MtpRejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -194,7 +196,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -194,7 +196,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
raise ValueError("Unknown speculative decoding method: " raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}") f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
self.use_mtp = self.speculative_config.method == "deepseek_mtp"
if not self.use_mtp:
self.rejection_sampler = RejectionSampler()
else:
self.rejection_sampler = MtpRejectionSampler()
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
...@@ -320,6 +328,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -320,6 +328,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# means this layer will perform attention using the keys and values # means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
self.draft_probs : Optional[DraftProbs] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
...@@ -379,6 +389,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -379,6 +389,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None) self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None) self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch. # Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and # NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and # scheduled_req_ids overlap. This happens when a request is aborted and
...@@ -387,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -387,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# and handling the second as a new request. # and handling the second as a new request.
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if self.use_mtp and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs. # Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids: for req_id, input_id in scheduler_output.free_encoder_input_ids:
...@@ -1541,7 +1556,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1541,7 +1556,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits = logits[spec_decode_metadata.target_logits_indices] target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler( output_token_ids = self.rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
None, # draft_probs self.draft_probs.get_probs(self.input_batch.req_ids) \
if self.draft_probs is not None else None, # draft_probs
target_logits, target_logits,
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1627,7 +1643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1627,7 +1643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = None spec_token_ids = None
else: else:
assert spec_decode_common_attn_metadata is not None assert spec_decode_common_attn_metadata is not None
spec_token_ids = self.propose_draft_token_ids( spec_token_ids, draft_probs = self.propose_draft_token_ids(
scheduler_output, scheduler_output,
valid_sampled_token_ids, valid_sampled_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1637,6 +1653,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1637,6 +1653,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_metadata, spec_decode_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
) )
if self.use_mtp:
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
spec_token_ids = spec_token_ids.tolist()
self.eplb_step() self.eplb_step()
...@@ -1743,7 +1767,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1743,7 +1767,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1) [h[token_indices] for h in aux_hidden_states], dim=-1)
else: else:
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
draft_token_ids = self.drafter.propose( spec_token_ids, draft_probs = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
...@@ -1752,8 +1776,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1752,8 +1776,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
num_rejected_tokens=num_rejected_tokens num_rejected_tokens=num_rejected_tokens
) )
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return spec_token_ids, draft_probs
@staticmethod @staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
...@@ -2200,7 +2224,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2200,7 +2224,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens, attn_metadata)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
...@@ -2267,10 +2291,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2267,10 +2291,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids, self.device) draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids) num_tokens = sum(len(ids) for ids in draft_token_ids)
# draft_probs = torch.randn( draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device, num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype) dtype=logits.dtype)
draft_probs = None # draft_probs = None
target_logits = torch.randn(num_tokens, target_logits = torch.randn(num_tokens,
logits.shape[-1], logits.shape[-1],
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment