Commit 89eecc55 authored by 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent a1239b53
...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
#@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -68,6 +68,7 @@ from vllm.v1.worker.block_table import BlockTable ...@@ -68,6 +68,7 @@ from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
...@@ -191,7 +192,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -191,7 +192,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
raise ValueError("Unknown speculative decoding method: " raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}") f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
# Request states. # Request states.
...@@ -1362,6 +1362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1362,6 +1362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors)
else:
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with set_forward_context( with set_forward_context(
...@@ -1688,7 +1694,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1688,7 +1694,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected_tokens=num_rejected_tokens num_rejected_tokens=num_rejected_tokens
) )
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return spec_token_ids
def kv_connector_no_forward( def kv_connector_no_forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment