Commit 04b61f0e authored by 王敏's avatar 王敏
Browse files

[fix]修复v1 mtp接受率低的问题

parent dea49b15
...@@ -546,6 +546,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -546,6 +546,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device = self.runner.device device = self.runner.device
block_table = self.block_table block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs] block_table_tensor = block_table.get_device_tensor()[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_actual_tokens].copy_( block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens], block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True) non_blocking=True)
......
...@@ -45,6 +45,8 @@ class CommonAttentionMetadata: ...@@ -45,6 +45,8 @@ class CommonAttentionMetadata:
"""(batch_size,), record the rejected tokens number in cpu and gpu""" """(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0 num_speculative_tokens: int = 0
"""Number of speculative tokens""" """Number of speculative tokens"""
slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping"""
M = TypeVar("M") M = TypeVar("M")
......
...@@ -159,6 +159,7 @@ class EagleProposer: ...@@ -159,6 +159,7 @@ class EagleProposer:
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
num_rejected_tokens=num_rejected_tokens, num_rejected_tokens=num_rejected_tokens,
slot_mapping=target_slot_mapping
) )
assert self.runner is not None assert self.runner is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment