Commit 65e29a89 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[feat]支持v1 engine mtp cudagraph

See merge request dcutoolkit/deeplearing/vllm!164
parents 741dbbbb fe393be8
......@@ -4776,6 +4776,11 @@ class VllmConfig:
if size <= max_num_tokens
]
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
......@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.num_reqs == m.num_actual_tokens, \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
#m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m)
def build(self, common_prefix_len: int,
......@@ -633,6 +656,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens.device, non_blocking=True)
decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None:
self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
......@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
#return common_attn_metadata.max_query_len == 1
return self._num_prefills == 0
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
......
......@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
"""Longest query in batch"""
num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
M = TypeVar("M")
......
This diff is collapsed.
......@@ -342,7 +342,7 @@ class EagleProposer:
)
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )](
prepare_eagle_input_kernel[(batch_size,)](
token_indices,
cu_target_query_lens,
cu_num_tokens,
......
......@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
if not is_profile and self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
......@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens,
)
for kv_cache_group_id, kv_cache_group_spec in enumerate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment