Commit 04b61f0e authored by 王敏's avatar 王敏
Browse files

[fix]修复v1 mtp接受率低的问题

parent dea49b15
......@@ -546,11 +546,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
......
......@@ -45,6 +45,8 @@ class CommonAttentionMetadata:
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping"""
M = TypeVar("M")
......
......@@ -159,6 +159,7 @@ class EagleProposer:
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
num_rejected_tokens=num_rejected_tokens,
slot_mapping=target_slot_mapping
)
assert self.runner is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment