Unverified commit 862dd76c, authored by Ke Bao, committed by GitHub

Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)

parent fb4c9c3a
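NextN wires DeepSeek-V3/R1's multi-token-prediction (MTP) module into the existing EAGLE speculative-decoding path: a single extra decoder layer drafts the next token by fusing the current token embedding with the target model's hidden state. The snippet below is a minimal, self-contained sketch of that fusion step as implemented by DeepseekModelNextN further down in this diff; the toy sizes, the plain nn.Embedding/nn.Linear stand-ins, and the bare rms_norm helper are illustrative assumptions, not the production modules (which use VocabParallelEmbedding, RMSNorm, and DeepseekV2DecoderLayer).

import torch
from torch import nn

def rms_norm(x, eps=1e-6):
    # Unweighted RMSNorm stand-in for sglang's RMSNorm layer.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

hidden_size, vocab_size = 16, 100                  # toy sizes, not DeepSeek-V3's
embed_tokens = nn.Embedding(vocab_size, hidden_size)
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

input_ids = torch.tensor([3, 7, 42])               # draft-step tokens
prev_hidden = torch.randn(3, hidden_size)          # target model's hidden states (spec_info.hidden_states)
fused = eh_proj(torch.cat((rms_norm(embed_tokens(input_ids)), rms_norm(prev_hidden)), dim=-1))
print(fused.shape)  # torch.Size([3, 16]); this is what feeds the single NextN decoder layer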
@@ -98,6 +98,7 @@ class ModelConfig:
         if (
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only DeepSeek NextN Speculative Decoding."""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
class DeepseekModelNextN(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.decoder = DeepseekV2DecoderLayer(
config, 0, quant_config=quant_config, is_nextn=True
)
self.shared_head = nn.Module()
self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
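        # Fuse the draft token embedding with the target model's hidden state
        # for the same position (carried in forward_batch.spec_info), then
        # project the concatenation back down to hidden_size.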
hidden_states = self.eh_proj(
torch.cat(
(
self.enorm(hidden_states),
self.hnorm(forward_batch.spec_info.hidden_states),
),
dim=-1,
)
)
residual = None
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual
)
if not forward_batch.forward_mode.is_idle():
hidden_states, _ = self.shared_head.norm(hidden_states, residual)
return hidden_states
class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.model = DeepseekModelNextN(config, quant_config)
if global_server_args_dict["enable_dp_attention"]:
self.model.shared_head.head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.model.shared_head.head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.model.shared_head.head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
assert num_nextn_layers == self.config.num_hidden_layers
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
)
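        # The exported NextN checkpoint (see export_deepseek_nextn.py below)
        # keeps its single MTP layer under the "model.layers.0" prefix;
        # NextN-specific modules are remapped onto "model.*" and everything
        # else onto the single "model.decoder" layer.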
nextn_layer_prefix = "model.layers.0"
nextn_spec_weight_names = [
"shared_head.head",
"shared_head.norm",
"eh_proj",
"embed_tokens",
"enorm",
"hnorm",
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if not name.startswith(nextn_layer_prefix):
continue
else:
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not global_server_args_dict["disable_mla"]:
self_attn = self.model.decoder.self_attn
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
w = ops.awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if is_hip_:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
input_scale=None,
)
else:
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if is_hip_:
self_attn.w_scale *= 2.0
EntryClass = [DeepseekV3ForCausalLMNextN]
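For reference, the prefix routing that load_weights applies to checkpoint names can be sketched on its own. The snippet below is an illustration only, applied to a few hypothetical checkpoint keys; it ignores the stacked-parameter and expert remapping that the real loader also performs.

NEXTN_LAYER_PREFIX = "model.layers.0"
NEXTN_SPEC_WEIGHT_NAMES = [
    "shared_head.head", "shared_head.norm", "eh_proj", "embed_tokens", "enorm", "hnorm",
]

def route_nextn_weight(name):
    # Weights outside the NextN layer are simply skipped by this draft model.
    if not name.startswith(NEXTN_LAYER_PREFIX):
        return None
    # NextN-specific modules live directly on the model; everything else
    # belongs to the single decoder layer.
    if any(w in name for w in NEXTN_SPEC_WEIGHT_NAMES):
        return name.replace(NEXTN_LAYER_PREFIX, "model")
    return name.replace(NEXTN_LAYER_PREFIX, "model.decoder")

print(route_nextn_weight("model.layers.0.eh_proj.weight"))              # model.eh_proj.weight
print(route_nextn_weight("model.layers.0.self_attn.kv_b_proj.weight"))  # model.decoder.self_attn.kv_b_proj.weight
print(route_nextn_weight("model.layers.5.mlp.down_proj.weight"))        # None (not a NextN weight)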
@@ -519,6 +519,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Triton: Use normal computation for prefill and use weight absorption for extend/decode
         if (
             forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
             and forward_batch.extend_prefix_lens.sum() == 0
         ):
             return self.forward_normal(positions, hidden_states, forward_batch)
@@ -680,6 +682,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        is_nextn: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -731,7 +734,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
-        if (
+        if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
@@ -262,14 +262,17 @@ class ServerArgs:
         )

         # Speculative Decoding
-        if self.speculative_algorithm == "EAGLE":
+        if (
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "NEXTN"
+        ):
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
             self.disable_radix_cache = True
             self.disable_overlap_schedule = True
             self.chunked_prefill_size = -1
             logger.info(
-                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
             )

         # GGUF
@@ -705,7 +708,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE"],
+            choices=["EAGLE", "NEXTN"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 logger = logging.getLogger(__name__)
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
         self.server_args = server_args

         # Share the embedding and lm_head
-        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-        self.model_runner.model.set_embed_and_head(embed, head)
+        if not self.speculative_algorithm.is_nextn():
+            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+            self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

         # Create multi-step attn backends and cuda graph runners
@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
+    # NEXTN spec decoding is for DeepSeek V3/R1
+    # currently it's implemented based on EAGLE
+    NEXTN = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
+        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
+
+    def is_nextn(self):
+        return self == SpeculativeAlgorithm.NEXTN

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
+        if name is not None:
+            name = name.upper()
         return name_map[name]
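Treating NEXTN as a flavor of EAGLE keeps the rest of the scheduler unchanged: every is_eagle() check still fires, while is_nextn() lets the EAGLE worker special-case the parts that differ (such as not sharing the embedding and lm_head, as in the worker hunk above). A condensed, self-contained copy of the enum, for illustration only:

from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    NEXTN = auto()  # DeepSeek V3/R1 MTP head, reusing the EAGLE machinery

    def is_eagle(self):
        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN

    def is_nextn(self):
        return self == SpeculativeAlgorithm.NEXTN

    @staticmethod
    def from_string(name):
        table = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "NEXTN": SpeculativeAlgorithm.NEXTN,
            None: SpeculativeAlgorithm.NONE,
        }
        return table[name.upper() if name is not None else None]

algo = SpeculativeAlgorithm.from_string("nextn")
print(algo.is_eagle(), algo.is_nextn())  # True True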
"""
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
Usage:
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
"""
import argparse
import json
import os
import shutil
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import AutoConfig
def get_nextn_layer_id(config):
if not hasattr(config, "num_hidden_layers"):
raise ValueError("'num_hidden_layers' not found in model config.")
return config.num_hidden_layers
def update_and_save_config(config, output_dir):
new_config = config.to_dict()
new_config.update(
{
"num_hidden_layers": 0,
"architectures": ["DeepseekV3ForCausalLMNextN"],
}
)
with open(os.path.join(output_dir, "config.json"), "w") as f:
json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
def copy_non_safetensors_files(input_dir, output_dir):
for filename in os.listdir(input_dir):
src_file_path = os.path.join(input_dir, filename)
if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
dst_file_path = os.path.join(output_dir, filename)
shutil.copy2(src_file_path, dst_file_path)
print(f"All non-safetensors files have been copied to {output_dir}")
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
    prefix = f"model.layers.{nextn_layer_id}"
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
params = {}
for filename in os.listdir(input_dir):
if not filename.endswith(".safetensors"):
continue
file_path = os.path.join(input_dir, filename)
print(f"Processing: {filename}")
try:
with safe_open(file_path, framework="pt") as f:
matching_keys = [k for k in f.keys() if k.startswith(prefix)]
if not matching_keys:
print(f" No parameters starting with '{prefix}' found")
continue
for key in matching_keys:
new_key = key.replace(prefix, "model.layers.0")
params[new_key] = f.get_tensor(key)
except Exception as e:
print(f" Error processing {filename}: {str(e)}")
if params:
print(f"Saving {len(params)} parameters to {output_path}")
save_file(params, output_path)
else:
print("No matching parameters found.")
# Update safetensors index
index_path = os.path.join(output_dir, "model.safetensors.index.json")
print(f"Updating safetensors index to {index_path}")
index_data = {"weight_map": {}}
for key in params:
index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
with open(index_path, "w") as f:
json.dump(index_data, f, indent=4)
print("All done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Export NextN layer paramerters for DeepSeek-V3/R1"
)
parser.add_argument(
"--input-dir",
type=str,
required=True,
help="Input HF model directory.",
)
parser.add_argument(
"--output-dir",
type=str,
required=True,
help="Output nextn model directory.",
)
args = parser.parse_args()
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
    nextn_layer_id = get_nextn_layer_id(config)
os.makedirs(args.output_dir, exist_ok=True)
copy_non_safetensors_files(args.input_dir, args.output_dir)
update_and_save_config(config, args.output_dir)
export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)
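Taken together: running this script against a DeepSeek-V3/R1 checkpoint produces a one-layer DeepseekV3ForCausalLMNextN model, which is then passed to the server as the speculative draft model alongside --speculative-algorithm NEXTN. The exact flag for pointing at the exported directory is not shown in this diff; presumably it is the same draft-model-path argument already used for EAGLE, so check ServerArgs for the current name.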