Commit cc6f327a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[feat]支持mtp模型full_cuda_graph

See merge request dcutoolkit/deeplearing/vllm!172
parents a15d668b bd58c289
...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
#@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
rarange = np.repeat(query_lens, query_lens) - arange - 1 rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True) block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave( decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...], block_table_tensor[:self._num_decodes, ...],
repeats, dim=0) repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0) decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None: if self.spec_decode_block_table_tensor is not None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -29,10 +32,10 @@ PADDING_SLOT_ID = -1 ...@@ -29,10 +32,10 @@ PADDING_SLOT_ID = -1
class EagleProposer: class EagleProposer:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
runner=None, runner=None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
...@@ -56,6 +59,9 @@ class EagleProposer: ...@@ -56,6 +59,9 @@ class EagleProposer:
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
self.use_full_cuda_graph = (
self.use_cuda_graph
and vllm_config.compilation_config.full_cuda_graph)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed( reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes)) self.vllm_config.compilation_config.cudagraph_capture_sizes))
...@@ -71,6 +77,9 @@ class EagleProposer: ...@@ -71,6 +77,9 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
device=device) device=device)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
...@@ -79,25 +88,25 @@ class EagleProposer: ...@@ -79,25 +88,25 @@ class EagleProposer:
dtype=torch.int32) dtype=torch.int32)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
target_token_ids: torch.Tensor, target_token_ids: torch.Tensor,
# [num_tokens] # [num_tokens]
target_positions: torch.Tensor, target_positions: torch.Tensor,
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
# [num_tokens] # [num_tokens]
target_slot_mapping: torch.Tensor, target_slot_mapping: torch.Tensor,
# [batch_size] # [batch_size]
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0 # [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor, cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req] # [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor, block_table: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: list[int], num_rejected_tokens: list[int],
# [batch_size] # [batch_size]
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -157,7 +166,7 @@ class EagleProposer: ...@@ -157,7 +166,7 @@ class EagleProposer:
# FIXME: need to consider multiple kv_cache_groups # FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build( attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0, common_prefix_len=0,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata
) )
else: else:
raise ValueError(f"Unsupported method: {self.method}") raise ValueError(f"Unsupported method: {self.method}")
...@@ -168,7 +177,7 @@ class EagleProposer: ...@@ -168,7 +177,7 @@ class EagleProposer:
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \ if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
...@@ -176,6 +185,38 @@ class EagleProposer: ...@@ -176,6 +185,38 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
...@@ -190,6 +231,7 @@ class EagleProposer: ...@@ -190,6 +231,7 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
...@@ -227,10 +269,10 @@ class EagleProposer: ...@@ -227,10 +269,10 @@ class EagleProposer:
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...] block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode( attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table, block_table_tensor=block_table,
seq_lens=(seq_lens + 1), seq_lens=seq_lens,
) )
for _ in range(self.num_speculative_tokens - 1): for i in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
...@@ -267,10 +309,10 @@ class EagleProposer: ...@@ -267,10 +309,10 @@ class EagleProposer:
# Compute the slot mapping. # Compute the slot mapping.
block_numbers = clamped_positions // self.block_size block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1, block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1)) index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1) block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size + attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size) clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length. # Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the # Otherwise, the KV cache will be inadvertently updated with the
# padding tokens. # padding tokens.
...@@ -282,6 +324,43 @@ class EagleProposer: ...@@ -282,6 +324,43 @@ class EagleProposer:
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -307,15 +386,16 @@ class EagleProposer: ...@@ -307,15 +386,16 @@ class EagleProposer:
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids return draft_token_ids
@staticmethod @staticmethod
def prepare_inputs( def prepare_inputs(
# [batch_size + 1] # [batch_size + 1]
cu_target_query_lens: torch.Tensor, cu_target_query_lens: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: torch.Tensor, num_rejected_tokens: torch.Tensor,
num_tokens: int, num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c] # cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3] # num_rejected_tokens: [n1, n2, n3]
...@@ -342,7 +422,7 @@ class EagleProposer: ...@@ -342,7 +422,7 @@ class EagleProposer:
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size,)]( prepare_eagle_input_kernel[(batch_size, )](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -362,8 +442,8 @@ class EagleProposer: ...@@ -362,8 +442,8 @@ class EagleProposer:
model_config=draft_model_config) model_config=draft_model_config)
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() - get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names) target_attn_layer_names)
self.attn_layer_names = list(draft_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names)
...@@ -376,8 +456,8 @@ class EagleProposer: ...@@ -376,8 +456,8 @@ class EagleProposer:
target_language_model = target_model target_language_model = target_model
# share embed_tokens with the target model if needed # share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \ if get_pp_group().world_size == 1 \
and self.method != "deepseek_mtp" \ and self.method != "deepseek_mtp" \
and self.model.model.embed_tokens.weight.shape \ and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape: == target_language_model.model.embed_tokens.weight.shape:
logger.info( logger.info(
"Assuming the EAGLE head shares the same vocab embedding" \ "Assuming the EAGLE head shares the same vocab embedding" \
...@@ -402,10 +482,15 @@ class EagleProposer: ...@@ -402,10 +482,15 @@ class EagleProposer:
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
...@@ -440,8 +525,8 @@ class EagleProposer: ...@@ -440,8 +525,8 @@ class EagleProposer:
# FIXME(woosuk): The logic here is duplicated with the main sampling code. # FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation. # We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token( def compute_probs_and_sample_next_token(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling. # For greedy requests, draft_probs is not used in rejection sampling.
......
...@@ -2083,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2083,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens, attn_metadata)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment