Commit 00e13357 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]支持v1 engine mtp cudagraph

parent 3de379de
...@@ -4802,6 +4802,11 @@ class VllmConfig: ...@@ -4802,6 +4802,11 @@ class VllmConfig:
if size <= max_num_tokens if size <= max_num_tokens
] ]
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -488,6 +488,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -488,6 +488,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.use_spec_decode = False self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32) self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc qo_indptr = prefill.query_start_loc
...@@ -589,11 +593,30 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -589,11 +593,30 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA. Currently, only decode is supported for full cudagraphs with MLA.
""" """
m = common_attn_metadata m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \ # assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \ # "MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq." # "Make sure all cudagraph capture sizes <= max_num_seq."
# m.max_query_len = 1 # decode-only
m.max_query_len = 1 # decode-only self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m) return self.build(0, m)
...@@ -742,6 +765,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -742,6 +765,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True)
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor, block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens, seq_lens=decode_seq_lens,
...@@ -775,7 +807,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -775,7 +807,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def can_run_in_cudagraph( def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool: self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# return common_attn_metadata.max_query_len == 1
if not self.use_spec_decode:
return common_attn_metadata.max_query_len == 1 return common_attn_metadata.max_query_len == 1
return self._num_prefills == 0
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
......
...@@ -55,6 +55,8 @@ class CommonAttentionMetadata: ...@@ -55,6 +55,8 @@ class CommonAttentionMetadata:
"""Longest query in batch""" """Longest query in batch"""
num_rejected_tokens: list[int] num_rejected_tokens: list[int]
"""(batch_size,), record the rejected tokens number in cpu and gpu""" """(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
......
This diff is collapsed.
...@@ -2091,6 +2091,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2091,6 +2091,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs) num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
if not is_profile and self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens assert sum(num_scheduled_tokens_list) == num_tokens
...@@ -2108,6 +2112,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2108,6 +2112,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups): self.kv_cache_config.kv_cache_groups):
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
...@@ -2121,6 +2127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2121,6 +2127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=num_tokens, max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens,
block_table_tensor=self.input_batch.block_table[ block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs], kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch. slot_mapping=self.input_batch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment