Unverified Commit 2628a69e authored by Jiayi Yao's avatar Jiayi Yao Committed by GitHub
Browse files

[V1] Support Deepseek MTP (#18435)


Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
parent 371f7e4c
...@@ -2255,7 +2255,7 @@ class DeviceConfig: ...@@ -2255,7 +2255,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator", SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"] "draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler", SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"] "typical_acceptance_sampler"]
...@@ -2519,6 +2519,15 @@ class SpeculativeConfig: ...@@ -2519,6 +2519,15 @@ class SpeculativeConfig:
elif (self.draft_model_config.hf_config.model_type == elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"): "mlp_speculator"):
self.method = "mlp_speculator" self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp"):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else: else:
self.method = "draft_model" self.method = "draft_model"
...@@ -2738,7 +2747,7 @@ class SpeculativeConfig: ...@@ -2738,7 +2747,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens return self.num_speculative_tokens
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3") return self.method in ("eagle", "eagle3", "deepseek_mtp")
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
......
...@@ -1338,7 +1338,7 @@ class EngineArgs: ...@@ -1338,7 +1338,7 @@ class EngineArgs:
is_ngram_enabled = True is_ngram_enabled = True
elif speculative_method == "medusa": elif speculative_method == "medusa":
is_medusa_enabled = True is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3"): elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True is_eagle_enabled = True
else: else:
speculative_model = self.speculative_config.get("model") speculative_model = self.speculative_config.get("model")
......
...@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors ...@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer, from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name) get_spec_layer_idx_from_weight_name)
from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -145,7 +146,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -145,7 +146,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return logits return logits
class DeepSeekMTP(nn.Module): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -10,9 +10,10 @@ from vllm.forward_context import set_forward_context ...@@ -10,9 +10,10 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -25,12 +26,15 @@ class EagleProposer: ...@@ -25,12 +26,15 @@ class EagleProposer:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
runner=None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.runner = runner
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
...@@ -106,24 +110,46 @@ class EagleProposer: ...@@ -106,24 +110,46 @@ class EagleProposer:
# FA requires seq_len to have dtype int32. # FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int() seq_lens = (target_positions[last_token_indices] + 1).int()
# FIXME(woosuk): The below two ops cause synchronization. Optimize. if self.method in ["eagle", "eagle3"]:
max_seq_len = seq_lens.max().item() # FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() max_seq_len = seq_lens.max().item()
attn_metadata = FlashAttentionMetadata( max_num_tokens = (cu_num_tokens[1:] -
num_actual_tokens=num_tokens, cu_num_tokens[:-1]).max().item()
max_query_len=max_num_tokens, attn_metadata = FlashAttentionMetadata(
query_start_loc=cu_num_tokens, num_actual_tokens=num_tokens,
max_seq_len=max_seq_len, max_query_len=max_num_tokens,
seq_lens=seq_lens, query_start_loc=cu_num_tokens,
block_table=block_table, max_seq_len=max_seq_len,
slot_mapping=target_slot_mapping, seq_lens=seq_lens,
# TODO(woosuk): Support cascade attention. block_table=block_table,
use_cascade=False, slot_mapping=target_slot_mapping,
common_prefix_len=0, # TODO(woosuk): Support cascade attention.
cu_prefix_query_lens=None, use_cascade=False,
prefix_kv_lens=None, common_prefix_len=0,
suffix_kv_lens=None, cu_prefix_query_lens=None,
) prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise ValueError(f"Unsupported method: {self.method}")
if self.use_cuda_graph and \ if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
...@@ -136,11 +162,15 @@ class EagleProposer: ...@@ -136,11 +162,15 @@ class EagleProposer:
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens], self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens], self.hidden_states[:num_input_tokens],
) )
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
...@@ -150,6 +180,10 @@ class EagleProposer: ...@@ -150,6 +180,10 @@ class EagleProposer:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
...@@ -215,9 +249,9 @@ class EagleProposer: ...@@ -215,9 +249,9 @@ class EagleProposer:
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size], self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size], self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size], self.hidden_states[:input_batch_size],
) )
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size],
...@@ -268,7 +302,7 @@ class EagleProposer: ...@@ -268,7 +302,7 @@ class EagleProposer:
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )]( prepare_eagle_input_kernel[(batch_size, )](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -320,9 +354,9 @@ class EagleProposer: ...@@ -320,9 +354,9 @@ class EagleProposer:
with set_forward_context(None, self.vllm_config, with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
input_ids=self.input_ids[:num_tokens], self.input_ids[:num_tokens],
positions=self.positions[:num_tokens], self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens], self.hidden_states[:num_tokens],
) )
...@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token( ...@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
next_token_ids, next_token_ids,
) )
return next_token_ids, probs return next_token_ids, probs
@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill one per-program slice of ``out_ptr`` with consecutive indices.

    Each program (indexed along grid axis 0) reads its output range
    ``[cu_num_tokens[pid], cu_num_tokens[pid + 1])`` and writes the values
    ``cu_query_lens[pid] + 0, cu_query_lens[pid] + 1, ...`` into it.

    Args:
        out_ptr: Destination buffer receiving the generated indices.
        cu_query_lens_ptr: Per-program starting value of the index run.
        cu_num_tokens_ptr: Boundaries of each program's output slice; entry
            ``pid + 1`` is read, so it must hold one more element than the
            number of launched programs.
        BLOCK_SIZE: Compile-time number of elements stored per iteration.
    """
    req_idx = tl.program_id(0)

    # Output slice owned by this program: [out_begin, out_end).
    out_begin = tl.load(cu_num_tokens_ptr + req_idx)
    out_end = tl.load(cu_num_tokens_ptr + req_idx + 1)
    slice_len = out_end - out_begin

    # First value of the consecutive run written into the slice.
    first_index = tl.load(cu_query_lens_ptr + req_idx)

    # Cover the slice in BLOCK_SIZE-wide chunks; the mask drops the lanes
    # of the final chunk that fall past the end of the slice.
    for chunk in tl.range(tl.cdiv(slice_len, BLOCK_SIZE)):
        lane = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + out_begin + lane,
            first_index + lane,
            mask=lane < slice_len,
        )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
...@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: ...@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
return False return False
return True return True
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill one per-program slice of ``out_ptr`` with consecutive indices.

    Each program (indexed along grid axis 0) reads its output range
    ``[cu_num_tokens[pid], cu_num_tokens[pid + 1])`` and writes the values
    ``cu_query_lens[pid] + 0, cu_query_lens[pid] + 1, ...`` into it.

    Args:
        out_ptr: Destination buffer receiving the generated indices.
        cu_query_lens_ptr: Per-program starting value of the index run.
        cu_num_tokens_ptr: Boundaries of each program's output slice; entry
            ``pid + 1`` is read, so it must hold one more element than the
            number of launched programs.
        BLOCK_SIZE: Compile-time number of elements stored per iteration.
    """
    req_idx = tl.program_id(0)

    # Output slice owned by this program: [out_begin, out_end).
    out_begin = tl.load(cu_num_tokens_ptr + req_idx)
    out_end = tl.load(cu_num_tokens_ptr + req_idx + 1)
    slice_len = out_end - out_begin

    # First value of the consecutive run written into the slice.
    first_index = tl.load(cu_query_lens_ptr + req_idx)

    # Cover the slice in BLOCK_SIZE-wide chunks; the mask drops the lanes
    # of the final chunk that fall past the end of the slice.
    for chunk in tl.range(tl.cdiv(slice_len, BLOCK_SIZE)):
        lane = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + out_begin + lane,
            first_index + lane,
            mask=lane < slice_len,
        )
...@@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.drafter = EagleProposer(self.vllm_config, self.device,
self.device) # type: ignore self) # type: ignore
if self.speculative_config.method == "eagle3": if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa": elif self.speculative_config.method == "medusa":
...@@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device) device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens] target_token_ids = self.input_ids[:num_scheduled_tokens]
...@@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping=target_slot_mapping, target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens, cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_table, block_table=block_table,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
...@@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.use_spec_decode and \ if self.use_spec_decode and self.speculative_config.use_eagle():
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment