Commit 29646389 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

[feat]deepseek mtp支持pp模式

See merge request dcutoolkit/deeplearing/vllm!503
parents ba2f6226 2ce72b9c
...@@ -37,6 +37,7 @@ from .deepseek_v2 import ( ...@@ -37,6 +37,7 @@ from .deepseek_v2 import (
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from .utils import maybe_prefix from .utils import maybe_prefix
from .interfaces import SupportsPP
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs import vllm.envs as envs
...@@ -194,7 +195,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -194,7 +195,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
......
...@@ -1902,7 +1902,7 @@ class GPUModelRunner( ...@@ -1902,7 +1902,7 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid) cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid] cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
...@@ -4840,35 +4840,36 @@ class GPUModelRunner( ...@@ -4840,35 +4840,36 @@ class GPUModelRunner(
self.speculative_config.use_eagle() self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model() or self.speculative_config.uses_draft_model()
): ):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) #assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert self.speculative_config is not None if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer | DraftModelProposer):
# Eagle currently only supports PIECEWISE cudagraphs. assert self.speculative_config is not None
# Therefore only use cudagraphs if the main model uses PIECEWISE # Eagle currently only supports PIECEWISE cudagraphs.
# NOTE(lucas): this is a hack, need to clean up. # Therefore only use cudagraphs if the main model uses PIECEWISE
use_cudagraphs = ( # NOTE(lucas): this is a hack, need to clean up.
( use_cudagraphs = (
is_graph_capturing (
and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE is_graph_capturing
) and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
or ( )
not is_graph_capturing or (
and cudagraph_runtime_mode != CUDAGraphMode.NONE not is_graph_capturing
and cudagraph_runtime_mode != CUDAGraphMode.NONE
)
) and not self.speculative_config.enforce_eager
# Note(gnovack) - We need to disable cudagraphs for one of the two
# lora cases when cudagraph_specialize_lora is enabled. This is a
# short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/28334
if self.compilation_config.cudagraph_specialize_lora and activate_lora:
use_cudagraphs = False
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
slot_mappings=slot_mappings,
) )
) and not self.speculative_config.enforce_eager
# Note(gnovack) - We need to disable cudagraphs for one of the two
# lora cases when cudagraph_specialize_lora is enabled. This is a
# short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/28334
if self.compilation_config.cudagraph_specialize_lora and activate_lora:
use_cudagraphs = False
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
slot_mappings=slot_mappings,
)
# We register layerwise NVTX hooks here after the first dynamo tracing is # We register layerwise NVTX hooks here after the first dynamo tracing is
# done to avoid nvtx operations in hook functions being traced by # done to avoid nvtx operations in hook functions being traced by
...@@ -5544,7 +5545,7 @@ class GPUModelRunner( ...@@ -5544,7 +5545,7 @@ class GPUModelRunner(
) )
# Initialize eagle's cudagraph dispatcher if using eagle spec decode. # Initialize eagle's cudagraph dispatcher if using eagle spec decode.
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle() and hasattr(self, "drafter"):
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.initialize_cudagraph_keys(cudagraph_mode) self.drafter.initialize_cudagraph_keys(cudagraph_mode)
...@@ -6091,10 +6092,11 @@ class GPUModelRunner( ...@@ -6091,10 +6092,11 @@ class GPUModelRunner(
self.speculative_config.use_eagle() self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model() or self.speculative_config.uses_draft_model()
): ):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) #assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# validate all draft model layers belong to the same kv cache if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer | DraftModelProposer):
# group # validate all draft model layers belong to the same kv cache
self.drafter.validate_same_kv_cache_group(kv_cache_config) # group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
if has_kv_transfer_group(): if has_kv_transfer_group():
kv_transfer_group = get_kv_transfer_group() kv_transfer_group = get_kv_transfer_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment