Commit fe393be8 authored by 王敏's avatar 王敏
Browse files

[feat]支持v1 engine mtp cudagraph

parent 741dbbbb
......@@ -4776,6 +4776,11 @@ class VllmConfig:
if size <= max_num_tokens
]
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
......@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.num_reqs == m.num_actual_tokens, \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
#m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m)
def build(self, common_prefix_len: int,
......@@ -633,10 +656,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens.device, non_blocking=True)
decode_seq_lens = decode_seq_lens - seq_lens_minus
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
if self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
......@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
#return common_attn_metadata.max_query_len == 1
return self._num_prefills == 0
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
......
......@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
"""Longest query in batch"""
num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
M = TypeVar("M")
......
This diff is collapsed.
......@@ -29,10 +29,10 @@ PADDING_SLOT_ID = -1
class EagleProposer:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config
......@@ -79,25 +79,25 @@ class EagleProposer:
dtype=torch.int32)
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
num_rejected_tokens: list[int],
# [batch_size]
sampling_metadata: SamplingMetadata
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
num_rejected_tokens: list[int],
# [batch_size]
sampling_metadata: SamplingMetadata
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -168,7 +168,7 @@ class EagleProposer:
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
......@@ -212,7 +212,7 @@ class EagleProposer:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
......@@ -259,7 +259,7 @@ class EagleProposer:
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
......@@ -267,10 +267,10 @@ class EagleProposer:
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
......@@ -311,11 +311,11 @@ class EagleProposer:
@staticmethod
def prepare_inputs(
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
......@@ -342,7 +342,7 @@ class EagleProposer:
)
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )](
prepare_eagle_input_kernel[(batch_size,)](
token_indices,
cu_target_query_lens,
cu_num_tokens,
......@@ -362,8 +362,8 @@ class EagleProposer:
model_config=draft_model_config)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
self.attn_layer_names = list(draft_attn_layer_names)
......@@ -376,8 +376,8 @@ class EagleProposer:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \
and self.method != "deepseek_mtp" \
and self.model.model.embed_tokens.weight.shape \
and self.method != "deepseek_mtp" \
and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding" \
......@@ -402,8 +402,8 @@ class EagleProposer:
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
......@@ -440,8 +440,8 @@ class EagleProposer:
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
......
......@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
if not is_profile and self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
......@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens,
)
for kv_cache_group_id, kv_cache_group_spec in enumerate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment