Commit 89626cfc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[feat]优化mtp/eagle的计算逻辑,减少第1层并行解码的计算重复

See merge request dcutoolkit/deeplearing/vllm!178
parents a6bf968b 619eb032
......@@ -637,12 +637,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = None
if self._num_decodes > 0:
if self.use_spec_decode:
if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
query_lens = self.num_scheduled_tokens_np[:self._num_decodes]
if common_attn_metadata.num_rejected_tokens is not None:
num_rejected_tokens = common_attn_metadata.num_rejected_tokens[:self._num_decodes]
query_lens = query_lens - np.array(num_rejected_tokens, dtype=np.int32)
self._num_decode_tokens -= sum(num_rejected_tokens)
cu_num_blocks = np.cumsum(query_lens)
virtual_batches = cu_num_blocks[-1]
block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
......@@ -672,10 +668,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
else:
self._num_decode_tokens = self._num_decodes
if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=seq_lens[:self._num_decode_tokens],
)
return self.metadata_cls(
......
......@@ -41,12 +41,11 @@ class CommonAttentionMetadata:
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
slot_mapping: torch.Tensor = None
"""(batch_size, seq_len), slot mapping"""
spec_layer_decoding: bool = False
M = TypeVar("M")
......
......@@ -104,9 +104,8 @@ class EagleProposer:
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
num_rejected_tokens: list[int],
# [batch_size]
sampling_metadata: SamplingMetadata
sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -158,8 +157,8 @@ class EagleProposer:
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
num_rejected_tokens=num_rejected_tokens,
slot_mapping=target_slot_mapping
slot_mapping=target_slot_mapping,
spec_layer_decoding=decoding
)
assert self.runner is not None
......@@ -186,7 +185,7 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
......@@ -220,7 +219,8 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
......@@ -390,45 +390,56 @@ class EagleProposer:
return draft_token_ids
# @staticmethod
# def prepare_inputs(
# # [batch_size + 1]
# cu_target_query_lens: torch.Tensor,
# # [batch_size]
# num_rejected_tokens: torch.Tensor,
# num_tokens: int,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# # cu_target_query_lens: [0, a, a + b, a + b + c]
# # num_rejected_tokens: [n1, n2, n3]
# # num_tokens_per_req: [a - n1, b - n2, c - n3]
# # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# # token_indices: [0, 1, ..., a - n1 - 1,
# # a, a + 1, ..., a + b - n2 - 1,
# # a + b, a + b + 1, ..., a + b + c - n3 - 1]
# # [0, a, a + b, a + b + c] -> [a, b, c]
# query_len_per_req = (cu_target_query_lens[1:] -
# cu_target_query_lens[:-1])
# # [a, b, c] -> [a - n1, b - n2, c - n3]
# num_tokens_per_req = query_len_per_req - num_rejected_tokens
# # [a - n1, b - n2, c - n3] ->
# # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# cu_num_tokens = torch.zeros_like(cu_target_query_lens)
# torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
# token_indices = torch.empty(
# num_tokens,
# dtype=torch.int32,
# device=cu_target_query_lens.device,
# )
# batch_size = num_rejected_tokens.shape[0]
# BLOCK_SIZE = 1024
# prepare_eagle_input_kernel[(batch_size, )](
# token_indices,
# cu_target_query_lens,
# cu_num_tokens,
# BLOCK_SIZE=BLOCK_SIZE,
# )
# return cu_num_tokens, token_indices
@staticmethod
def prepare_inputs(
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_target_query_lens.device,
)
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )](
token_indices,
cu_target_query_lens,
cu_num_tokens,
BLOCK_SIZE=BLOCK_SIZE,
)
cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device)
token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None:
......
......@@ -1659,7 +1659,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
block_table = None
num_rejected_tokens = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
......@@ -1675,21 +1674,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens,
num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
num_accepted_tokens_tensor = async_tensor_h2d(
num_accepted_tokens,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens_tensor,
num_tokens,
num_accepted_tokens_tensor,
)
target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE.
......@@ -1710,7 +1703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens=cu_num_tokens,
block_table=block_table,
sampling_metadata=sampling_metadata,
num_rejected_tokens=num_rejected_tokens
decoding=spec_decode_metadata is not None
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment