Commit 65e29a89 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[feat]支持v1 engine mtp cudagraph

See merge request dcutoolkit/deeplearing/vllm!164
parents 741dbbbb fe393be8
...@@ -4776,6 +4776,11 @@ class VllmConfig: ...@@ -4776,6 +4776,11 @@ class VllmConfig:
if size <= max_num_tokens if size <= max_num_tokens
] ]
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.use_spec_decode = False self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32) self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
...@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA. Currently, only decode is supported for full cudagraphs with MLA.
""" """
m = common_attn_metadata m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \ # assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \ # "MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq." # "Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only #m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch. # Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0 self._num_prefills = 0
self._num_prefill_tokens = 0 self._num_prefill_tokens = 0
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m) return self.build(0, m)
def build(self, common_prefix_len: int, def build(self, common_prefix_len: int,
...@@ -633,10 +656,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -633,10 +656,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True)
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
decode_metadata = self._build_decode( if self.spec_decode_block_table_tensor is not None:
block_table_tensor=decode_block_table_tensor, self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
seq_lens=decode_seq_lens, self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
)
decode_metadata = self._build_decode(
block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
)
else:
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
else: else:
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...], block_table_tensor=block_table_tensor[:self._num_decodes, ...],
...@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def can_run_in_cudagraph( def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool: self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1 #return common_attn_metadata.max_query_len == 1
return self._num_prefills == 0
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
......
...@@ -43,6 +43,8 @@ class CommonAttentionMetadata: ...@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
"""Longest query in batch""" """Longest query in batch"""
num_rejected_tokens: list[int] = None num_rejected_tokens: list[int] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu""" """(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens: int = 0
"""Number of speculative tokens"""
M = TypeVar("M") M = TypeVar("M")
......
This diff is collapsed.
...@@ -29,10 +29,10 @@ PADDING_SLOT_ID = -1 ...@@ -29,10 +29,10 @@ PADDING_SLOT_ID = -1
class EagleProposer: class EagleProposer:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
runner=None, runner=None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
...@@ -79,25 +79,25 @@ class EagleProposer: ...@@ -79,25 +79,25 @@ class EagleProposer:
dtype=torch.int32) dtype=torch.int32)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
target_token_ids: torch.Tensor, target_token_ids: torch.Tensor,
# [num_tokens] # [num_tokens]
target_positions: torch.Tensor, target_positions: torch.Tensor,
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
# [num_tokens] # [num_tokens]
target_slot_mapping: torch.Tensor, target_slot_mapping: torch.Tensor,
# [batch_size] # [batch_size]
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0 # [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor, cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req] # [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor, block_table: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: list[int], num_rejected_tokens: list[int],
# [batch_size] # [batch_size]
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -168,7 +168,7 @@ class EagleProposer: ...@@ -168,7 +168,7 @@ class EagleProposer:
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \ if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
...@@ -212,7 +212,7 @@ class EagleProposer: ...@@ -212,7 +212,7 @@ class EagleProposer:
hidden_states = hidden_states[last_token_indices] hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \ if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]: batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
...@@ -259,7 +259,7 @@ class EagleProposer: ...@@ -259,7 +259,7 @@ class EagleProposer:
# Consider max model length. # Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len) self.max_model_len)
# For the requests that exceed the max model length, we set the # For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention. # sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
...@@ -267,10 +267,10 @@ class EagleProposer: ...@@ -267,10 +267,10 @@ class EagleProposer:
# Compute the slot mapping. # Compute the slot mapping.
block_numbers = clamped_positions // self.block_size block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1, block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1)) index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1) block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size + attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size) clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length. # Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the # Otherwise, the KV cache will be inadvertently updated with the
# padding tokens. # padding tokens.
...@@ -311,11 +311,11 @@ class EagleProposer: ...@@ -311,11 +311,11 @@ class EagleProposer:
@staticmethod @staticmethod
def prepare_inputs( def prepare_inputs(
# [batch_size + 1] # [batch_size + 1]
cu_target_query_lens: torch.Tensor, cu_target_query_lens: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: torch.Tensor, num_rejected_tokens: torch.Tensor,
num_tokens: int, num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c] # cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3] # num_rejected_tokens: [n1, n2, n3]
...@@ -342,7 +342,7 @@ class EagleProposer: ...@@ -342,7 +342,7 @@ class EagleProposer:
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )]( prepare_eagle_input_kernel[(batch_size,)](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -362,8 +362,8 @@ class EagleProposer: ...@@ -362,8 +362,8 @@ class EagleProposer:
model_config=draft_model_config) model_config=draft_model_config)
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() - get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names) target_attn_layer_names)
self.attn_layer_names = list(draft_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names)
...@@ -376,8 +376,8 @@ class EagleProposer: ...@@ -376,8 +376,8 @@ class EagleProposer:
target_language_model = target_model target_language_model = target_model
# share embed_tokens with the target model if needed # share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \ if get_pp_group().world_size == 1 \
and self.method != "deepseek_mtp" \ and self.method != "deepseek_mtp" \
and self.model.model.embed_tokens.weight.shape \ and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape: == target_language_model.model.embed_tokens.weight.shape:
logger.info( logger.info(
"Assuming the EAGLE head shares the same vocab embedding" \ "Assuming the EAGLE head shares the same vocab embedding" \
...@@ -402,8 +402,8 @@ class EagleProposer: ...@@ -402,8 +402,8 @@ class EagleProposer:
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
...@@ -440,8 +440,8 @@ class EagleProposer: ...@@ -440,8 +440,8 @@ class EagleProposer:
# FIXME(woosuk): The logic here is duplicated with the main sampling code. # FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation. # We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token( def compute_probs_and_sample_next_token(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling. # For greedy requests, draft_probs is not used in rejection sampling.
......
...@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs) num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
if not is_profile and self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens assert sum(num_scheduled_tokens_list) == num_tokens
...@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking=True) non_blocking=True)
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=num_tokens, max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens,
) )
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment