Commit 4874e3e0 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]优化mtp/eagle的计算逻辑,减少第1层并行解码的计算重复(num_accepted_tokens_tensor修改暂未合入)

parent 295dfac8
...@@ -754,12 +754,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -754,12 +754,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = None decode_metadata = None
if num_decodes > 0: if num_decodes > 0:
if self.use_spec_decode: if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
query_lens = self.num_scheduled_tokens_np[:num_decodes] query_lens = self.num_scheduled_tokens_np[:num_decodes]
if common_attn_metadata.num_rejected_tokens is not None:
num_rejected_tokens = common_attn_metadata.num_rejected_tokens[:num_decodes]
query_lens = query_lens - np.array(num_rejected_tokens, dtype=np.int32)
self._num_decode_tokens -= sum(num_rejected_tokens)
cu_num_blocks = np.cumsum(query_lens) cu_num_blocks = np.cumsum(query_lens)
virtual_batches = cu_num_blocks[-1] virtual_batches = cu_num_blocks[-1]
block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens) block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
...@@ -789,10 +785,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -789,10 +785,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
block_table_tensor=decode_block_table_tensor, block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens, seq_lens=decode_seq_lens,
) )
else:
self._num_decode_tokens = num_decodes
if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else: else:
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...], block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=seq_lens[:num_decodes], seq_lens=seq_lens[:self._num_decode_tokens],
) )
attn_metadata = self.metadata_cls( attn_metadata = self.metadata_cls(
......
...@@ -58,12 +58,11 @@ class CommonAttentionMetadata: ...@@ -58,12 +58,11 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0 num_speculative_tokens: int = 0
"""Number of speculative tokens""" """Number of speculative tokens"""
slot_mapping: torch.Tensor = None slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping""" """(batch_size, seq_len), slot mapping"""
spec_layer_decoding: bool = False
M = TypeVar("M") M = TypeVar("M")
......
...@@ -98,6 +98,7 @@ class EagleProposer: ...@@ -98,6 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -141,7 +142,7 @@ class EagleProposer: ...@@ -141,7 +142,7 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]): and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph assert self.attn_metadata_cudagraph
if self.method == "deepseek_mtp": if self.method == "deepseek_mtp":
...@@ -166,7 +167,8 @@ class EagleProposer: ...@@ -166,7 +167,8 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model( ret_hidden_states = self.model(
self.input_ids[:num_input_tokens], self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens], self.positions[:num_input_tokens],
...@@ -329,9 +331,11 @@ class EagleProposer: ...@@ -329,9 +331,11 @@ class EagleProposer:
def prepare_inputs( def prepare_inputs(
self, self,
# cu_target_query_lens: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
# [batch_size] # [batch_size]
num_rejected_tokens: torch.Tensor num_rejected_tokens: torch.Tensor,
# num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]: ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
""" """
This function is used to prepare the inputs for the spec decode. This function is used to prepare the inputs for the spec decode.
...@@ -403,6 +407,7 @@ class EagleProposer: ...@@ -403,6 +407,7 @@ class EagleProposer:
token_indices_np = token_offests + old_query_start_locs_expanded token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to( token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True) device, non_blocking=True)
# token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
spec_common_attn_metadata = CommonAttentionMetadata( spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device, query_start_loc=new_query_start_loc_cpu.to(device,
......
...@@ -1732,7 +1732,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1732,7 +1732,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
num_rejected_tokens = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens] target_token_ids = self.input_ids[:num_scheduled_tokens]
...@@ -1757,6 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1757,6 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.drafter.prepare_inputs( self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu) common_attn_metadata, num_rejected_tokens_cpu)
# num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
# num_accepted_tokens_tensor = async_tensor_h2d(
# num_accepted_tokens,
# dtype=torch.int32,
# target_device=self.device,
# pin_memory=True)
# num_accepted_tokens_cpu = torch.tensor(num_accepted_tokens,
# dtype=torch.int32)
# common_attn_metadata, token_indices =\
# self.drafter.prepare_inputs(
# common_attn_metadata, num_accepted_tokens_cpu)
target_token_ids = self.input_ids[token_indices] target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE. # TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices] target_positions = self.positions[token_indices]
...@@ -1772,7 +1784,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1772,7 +1784,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
num_rejected_tokens=num_rejected_tokens decoding=spec_decode_metadata is not None
) )
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment