Unverified commit d55e446d, authored by qizixi, committed by GitHub
Browse files

[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)


Signed-off-by: qizixi <qizixi@meta.com>
parent ec82c3e3
......@@ -100,8 +100,12 @@ def test_prepare_inputs():
dtype=torch.int32,
device=device)
# a + b + c - n1 - n2 - n3
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
).item()
cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
cu_target_query_lens, num_rejected_tokens)
cu_target_query_lens, num_rejected_tokens, num_tokens)
assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
......
......@@ -271,6 +271,7 @@ class EagleProposer:
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
......@@ -288,18 +289,13 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.empty_like(cu_target_query_lens)
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
device=cu_target_query_lens.device,
)
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )](
......
......@@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
is_pin_memory_available)
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
......@@ -281,7 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
......@@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
next_token_ids = async_tensor_h2d(next_token_ids,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
......@@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
target_device=self.device,
pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens,
num_rejected_tokens_tensor,
num_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
......@@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment