Commit db554c6c authored by zhuwenwen
Browse files

[fix]修复v1 mtp接受率低的问题

更改默认的 full_cuda_graph 启动方式为 False
parent e97b3191
...@@ -4135,7 +4135,7 @@ class CompilationConfig: ...@@ -4135,7 +4135,7 @@ class CompilationConfig:
are always used, it can set this to False. Otherwise, it should are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.""" internally managed buffer. Default is False."""
full_cuda_graph: bool = True full_cuda_graph: bool = False
"""whether to use a full cuda graph for the entire forward pass rather than """whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide flag cannot be used together with splitting_ops. This may provide
......
...@@ -219,6 +219,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -219,6 +219,7 @@ from vllm.v1.attention.backends.utils import (
get_per_layer_parameters, infer_global_hyperparameters, get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -413,6 +414,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -413,6 +414,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
block_table: BlockTable,
metadata_cls: Optional[type[M]] = None): metadata_cls: Optional[type[M]] = None):
self.metadata_cls = metadata_cls \ self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata if metadata_cls is not None else MLACommonMetadata
...@@ -485,6 +487,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -485,6 +487,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device, device=device,
) )
self.block_table = block_table
self.use_spec_decode = False self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32) self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
...@@ -632,8 +636,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -632,8 +636,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because # function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels. # it blocks on all previous kernels.
device = self.device device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......
...@@ -57,12 +57,13 @@ class CommonAttentionMetadata: ...@@ -57,12 +57,13 @@ class CommonAttentionMetadata:
"""Longest query in batch""" """Longest query in batch"""
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
num_rejected_tokens: list[int] = None num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu""" """(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0 num_speculative_tokens: int = 0
"""Number of speculative tokens""" """Number of speculative tokens"""
slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping"""
M = TypeVar("M") M = TypeVar("M")
......
...@@ -144,16 +144,7 @@ class EagleProposer: ...@@ -144,16 +144,7 @@ class EagleProposer:
if (self.use_full_cuda_graph if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]): and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]: if self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = ( self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens) attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = ( self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
...@@ -288,21 +279,7 @@ class EagleProposer: ...@@ -288,21 +279,7 @@ class EagleProposer:
if (self.use_full_cuda_graph if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]): and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]: if self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = ( self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens) attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = ( self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment