Commit cc6f327a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[feat]支持mtp模型full_cuda_graph

See merge request dcutoolkit/deeplearing/vllm!172
parents a15d668b bd58c289
...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
#@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
rarange = np.repeat(query_lens, query_lens) - arange - 1 rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True) block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave( decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...], block_table_tensor[:self._num_decodes, ...],
repeats, dim=0) repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0) decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None: if self.spec_decode_block_table_tensor is not None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -56,6 +59,9 @@ class EagleProposer: ...@@ -56,6 +59,9 @@ class EagleProposer:
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
self.use_full_cuda_graph = (
self.use_cuda_graph
and vllm_config.compilation_config.full_cuda_graph)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed( reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes)) self.vllm_config.compilation_config.cudagraph_capture_sizes))
...@@ -71,6 +77,9 @@ class EagleProposer: ...@@ -71,6 +77,9 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
device=device) device=device)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
...@@ -157,7 +166,7 @@ class EagleProposer: ...@@ -157,7 +166,7 @@ class EagleProposer:
# FIXME: need to consider multiple kv_cache_groups # FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build( attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0, common_prefix_len=0,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata
) )
else: else:
raise ValueError(f"Unsupported method: {self.method}") raise ValueError(f"Unsupported method: {self.method}")
...@@ -176,6 +185,38 @@ class EagleProposer: ...@@ -176,6 +185,38 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
...@@ -190,6 +231,7 @@ class EagleProposer: ...@@ -190,6 +231,7 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
...@@ -227,10 +269,10 @@ class EagleProposer: ...@@ -227,10 +269,10 @@ class EagleProposer:
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...] block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode( attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table, block_table_tensor=block_table,
seq_lens=(seq_lens + 1), seq_lens=seq_lens,
) )
for _ in range(self.num_speculative_tokens - 1): for i in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
...@@ -282,6 +324,43 @@ class EagleProposer: ...@@ -282,6 +324,43 @@ class EagleProposer:
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -307,6 +386,7 @@ class EagleProposer: ...@@ -307,6 +386,7 @@ class EagleProposer:
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids return draft_token_ids
@staticmethod @staticmethod
...@@ -342,7 +422,7 @@ class EagleProposer: ...@@ -342,7 +422,7 @@ class EagleProposer:
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size,)]( prepare_eagle_input_kernel[(batch_size, )](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -404,8 +484,13 @@ class EagleProposer: ...@@ -404,8 +484,13 @@ class EagleProposer:
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
......
...@@ -2083,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2083,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens, attn_metadata)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment