Commit 4874e3e0 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]优化mtp/eagle的计算逻辑,减少第1层并行解码的计算重复(num_accepted_tokens_tensor修改暂未合入)

parent 295dfac8
......@@ -754,12 +754,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = None
if num_decodes > 0:
if self.use_spec_decode:
if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
query_lens = self.num_scheduled_tokens_np[:num_decodes]
if common_attn_metadata.num_rejected_tokens is not None:
num_rejected_tokens = common_attn_metadata.num_rejected_tokens[:num_decodes]
query_lens = query_lens - np.array(num_rejected_tokens, dtype=np.int32)
self._num_decode_tokens -= sum(num_rejected_tokens)
cu_num_blocks = np.cumsum(query_lens)
virtual_batches = cu_num_blocks[-1]
block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
......@@ -790,10 +786,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens=decode_seq_lens,
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
)
self._num_decode_tokens = num_decodes
if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=seq_lens[:self._num_decode_tokens],
)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
......
......@@ -58,12 +58,11 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping"""
spec_layer_decoding: bool = False
M = TypeVar("M")
......
......@@ -98,6 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -141,7 +142,7 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method == "deepseek_mtp":
......@@ -166,7 +167,8 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
......@@ -329,9 +331,11 @@ class EagleProposer:
def prepare_inputs(
self,
# cu_target_query_lens: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor
num_rejected_tokens: torch.Tensor,
# num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for the spec decode.
......@@ -403,6 +407,7 @@ class EagleProposer:
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
# token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
......
......@@ -1732,7 +1732,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device=self.device)
num_rejected_tokens = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
......@@ -1757,6 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu)
# num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
# num_accepted_tokens_tensor = async_tensor_h2d(
# num_accepted_tokens,
# dtype=torch.int32,
# target_device=self.device,
# pin_memory=True)
# num_accepted_tokens_cpu = torch.tensor(num_accepted_tokens,
# dtype=torch.int32)
# common_attn_metadata, token_indices =\
# self.drafter.prepare_inputs(
# common_attn_metadata, num_accepted_tokens_cpu)
target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices]
......@@ -1772,7 +1784,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
num_rejected_tokens=num_rejected_tokens
decoding=spec_decode_metadata is not None
)
spec_token_ids = draft_token_ids.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment