Commit 89eecc55 authored by 王敏's avatar 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent a1239b53
......@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata)
return logits
#@support_torch_compile
@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -68,6 +68,7 @@ from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
......@@ -191,7 +192,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
# Request states.
......@@ -1362,6 +1362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors)
else:
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
......@@ -1688,7 +1694,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected_tokens=num_rejected_tokens
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def kv_connector_no_forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment