Unverified commit dd408ee4 authored by Ke Bao, committed by GitHub

Auto set draft model path for MTP (#5793)

parent 9419e75d
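
For context, this change makes MTP speculative decoding on DeepSeek-V3/R1 usable without an explicit draft checkpoint: when no speculative draft model path is given and the target architecture is DeepseekV3ForCausalLM, the target model path is reused as the draft path, and the draft worker loads that checkpoint under the DeepseekV3ForCausalLMNextN architecture. Below is a minimal standalone sketch of that rule; the helper name is hypothetical and not part of sglang's API, which applies the same logic inside ServerArgs (see the server_args.py hunk later in this diff).

```python
from typing import Optional


def resolve_draft_model_path(
    model_path: str,
    draft_model_path: Optional[str],
    model_arch: str,
) -> Optional[str]:
    # Hypothetical helper mirroring the rule added to ServerArgs: for MTP
    # models the draft (NextN) weights live inside the target checkpoint,
    # so the target path can double as the draft path.
    if draft_model_path is None and model_arch in ["DeepseekV3ForCausalLM"]:
        return model_path
    return draft_model_path


if __name__ == "__main__":
    # DeepSeek-V3/R1: the draft path defaults to the target path.
    assert (
        resolve_draft_model_path("deepseek-ai/DeepSeek-V3", None, "DeepseekV3ForCausalLM")
        == "deepseek-ai/DeepSeek-V3"
    )
    # Other architectures still require an explicit draft model.
    assert resolve_draft_model_path("some-org/some-llama", None, "LlamaForCausalLM") is None
```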
......@@ -47,6 +47,7 @@ class ModelConfig:
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
) -> None:
self.model_path = model_path
......@@ -85,6 +86,12 @@ class ModelConfig:
else:
enable_multimodal = True
if (
is_draft_model
and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
# Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
......
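
The ModelConfig hunks above rewrite the HF config's architecture when the config is constructed for a draft worker, so the draft model class DeepseekV3ForCausalLMNextN is instantiated even though the checkpoint's config declares DeepseekV3ForCausalLM. A small sketch of that behavior, using a stand-in config object instead of a real HF config:

```python
from types import SimpleNamespace


def maybe_swap_to_nextn_arch(hf_config, is_draft_model: bool):
    # Mirrors the check added to ModelConfig.__init__ above: a draft worker
    # loading a DeepSeek-V3 checkpoint should build the NextN (MTP) head
    # rather than the full model.
    if is_draft_model and hf_config.architectures[0] == "DeepseekV3ForCausalLM":
        hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
    return hf_config


cfg = SimpleNamespace(architectures=["DeepseekV3ForCausalLM"])
maybe_swap_to_nextn_arch(cfg, is_draft_model=True)
assert cfg.architectures[0] == "DeepseekV3ForCausalLMNextN"
```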
......@@ -71,6 +71,7 @@ class TpModelWorker:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
is_draft_model=is_draft_worker,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
......
......@@ -692,9 +692,14 @@ class ModelRunner:
self.device, self.gpu_id, distributed=self.tp_size > 1
)
if self.use_mla_backend:
num_layers = (
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
)
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* self.model_config.num_hidden_layers
* num_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
else:
......@@ -809,7 +814,11 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
layer_num=(
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
),
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
......
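
With the MLA backend, the per-token KV cache cell now scales with the number of layers the worker actually owns: the full stack for the target model, but only num_nextn_predict_layers for the draft (NextN) worker. A back-of-the-envelope sketch with DeepSeek-V3-like values (kv_lora_rank=512, qk_rope_head_dim=64, 61 hidden layers, 1 NextN layer); the element size is an assumption for illustration:

```python
def mla_cell_size_bytes(kv_lora_rank: int, qk_rope_head_dim: int,
                        num_layers: int, element_size: int) -> int:
    # Same formula as the ModelRunner hunk above: one compressed KV entry of
    # (kv_lora_rank + qk_rope_head_dim) elements per layer per cached token.
    return (kv_lora_rank + qk_rope_head_dim) * num_layers * element_size


ELEM = 2  # bf16 bytes per element (assumption; an fp8 KV cache would use 1)
target = mla_cell_size_bytes(512, 64, num_layers=61, element_size=ELEM)
draft = mla_cell_size_bytes(512, 64, num_layers=1, element_size=ELEM)
print(target, draft)  # 70272 vs 1152 bytes per token: the draft worker's
                      # KV pool is a small fraction of the target's.
```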
......@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
assert num_nextn_layers == self.config.num_hidden_layers
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion > 0:
logger.info(
f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
)
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
else:
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale_inv",
"gate_proj.weight",
"gate_proj.weight_scale_inv",
"up_proj.weight",
"up_proj.weight_scale_inv",
]
names_to_remove = []
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.0.mlp.shared_experts.{suffix}"
)
for num_repeat in range(self.n_share_experts_fusion):
weights_list.append(
(
f"model.layers.0."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
nextn_layer_prefix = "model.layers.0"
nextn_spec_weight_names = [
"shared_head.norm",
"eh_proj",
"enorm",
"hnorm",
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Handle fused_qkv_a_proj
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = name.replace(
"q_a_proj", "fused_qkv_a_proj_with_mqa"
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
self_attn = self.model.decoder.self_attn
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
).T
else:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
input_scale=None,
)
else:
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if w.dtype == torch.int8:
if hasattr(self.quant_config, "weight_block_size"):
# block-wise int8 need it
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
torch.bfloat16
)
else:
# channel-wise int8 need it
assert hasattr(self_attn.kv_b_proj, "weight_scale")
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
super().load_weights(weights, is_nextn=True)
EntryClass = [DeepseekV3ForCausalLMNextN]
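
The duplicated NextN loader above is removed in favor of a single call to super().load_weights(weights, is_nextn=True); the shared loader in DeepseekV2ForCausalLM (next file) performs the same NextN-specific renaming. Below is a sketch of that renaming rule, assuming the single NextN layer sits at checkpoint prefix model.layers.{nextn_layer_id}; the function name is illustrative, not an actual helper in the codebase.

```python
NEXTN_SPEC_WEIGHT_NAMES = ["shared_head.norm", "eh_proj", "enorm", "hnorm"]


def remap_nextn_weight_name(name: str, nextn_layer_id: int):
    # Mirrors the loop in the shared loader: NextN-specific tensors map to the
    # module root ("model"), decoder-layer tensors map to the single decoder
    # layer ("model.decoder"), and head/embedding weights are skipped because
    # they are shared with the target model.
    prefix = f"model.layers.{nextn_layer_id}"
    if not name.startswith(prefix):
        return None  # not a NextN weight
    if "shared_head.head" in name or "embed_tokens" in name:
        return None  # reused from the target model
    for spec in NEXTN_SPEC_WEIGHT_NAMES:
        if spec in name:
            return name.replace(prefix, "model")
    return name.replace(prefix, "model.decoder")


assert remap_nextn_weight_name("model.layers.61.eh_proj.weight", 61) == "model.eh_proj.weight"
assert (
    remap_nextn_weight_name("model.layers.61.self_attn.kv_b_proj.weight", 61)
    == "model.decoder.self_attn.kv_b_proj.weight"
)
assert remap_nextn_weight_name("model.layers.61.embed_tokens.weight", 61) is None
```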
......@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
input_ids, hidden_states, self.lm_head, forward_batch
)
def post_load_weights(self):
def post_load_weights(self, is_nextn=False):
# Perform post-processing after loading weights
for layer_id in range(self.config.num_hidden_layers):
self_attn = self.model.layers[layer_id].self_attn
layer_ids = (
range(self.config.num_hidden_layers)
if not is_nextn
else [self.config.num_hidden_layers]
)
for layer_id in layer_ids:
self_attn = (
self.model.layers[layer_id].self_attn
if not is_nextn
else self.model.decoder.self_attn
)
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda:
......@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_vc = w_vc.contiguous()
self_attn.use_deep_gemm_bmm = True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
# compatible with old design
nextn_layer_id = (
0
if self.config.num_hidden_layers == 1
else self.config.num_hidden_layers
)
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
......@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
"up_proj.weight_scale_inv",
]
names_to_remove = []
for moe_layer in tqdm(
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
),
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in tqdm(
moe_layers,
desc=f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE",
):
......@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
"shared_head.norm",
"eh_proj",
"enorm",
"hnorm",
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# TODO(HandH1998): Modify it when nextn is supported.
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
else:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
......@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
)
weight_loader(param, loaded_weight)
self.post_load_weights()
self.post_load_weights(is_nextn=is_nextn)
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
......
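
One subtle point in the shared loader: when called with is_nextn=True, the NextN layer index depends on the checkpoint layout. A standalone NextN export (the "old design") has num_hidden_layers == 1 and keeps the layer at index 0, while the full DeepSeek-V3 checkpoint stores the MTP layer right after the last decoder layer (index num_hidden_layers, i.e. 61 for DeepSeek-V3). A sketch of that selection with illustrative config values:

```python
from types import SimpleNamespace


def nextn_layer_index(config) -> int:
    # Mirrors the "compatible with old design" branch in load_weights above:
    # a standalone NextN export keeps its single layer at index 0; the full
    # checkpoint appends the MTP layer after the decoder stack.
    assert getattr(config, "num_nextn_predict_layers", 0) == 1, \
        "Only 1 nextn layer is supported"
    return 0 if config.num_hidden_layers == 1 else config.num_hidden_layers


full = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)
standalone = SimpleNamespace(num_hidden_layers=1, num_nextn_predict_layers=1)
assert nextn_layer_index(full) == 61
assert nextn_layer_index(standalone) == 0
```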
......@@ -22,7 +22,7 @@ import random
import tempfile
from typing import List, Literal, Optional
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
configure_ipv6,
......@@ -333,6 +333,14 @@ class ServerArgs:
"eagle speculative decoding."
)
model_arch = get_model_arch(self)
# Auto set draft_model_path for DeepSeek-V3/R1
if self.speculative_draft_model_path is None and model_arch in [
"DeepseekV3ForCausalLM"
]:
self.speculative_draft_model_path = self.model_path
# Auto choose parameters
if self.speculative_num_steps is None:
assert (
......@@ -343,7 +351,7 @@ class ServerArgs:
self.speculative_num_steps,
self.speculative_eagle_topk,
self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self)
) = auto_choose_speculative_params(model_arch)
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
......@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
raise ValueError(self.help)
def auto_choose_speculative_params(self: ServerArgs):
def get_model_arch(args: ServerArgs):
hf_config = get_config(
args.model_path,
trust_remote_code=args.trust_remote_code,
revision=args.revision,
model_override_args=json.loads(args.json_model_override_args),
)
return hf_config.architectures[0]
def auto_choose_speculative_params(arch: str):
"""
Automatically choose the parameters for speculative decoding.
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
"""
config_path = os.path.join(self.model_path, "config.json")
if not os.path.exists(config_path):
raise ValueError(f"{config_path} is not found.")
config = json.load(open(config_path))
arch = config.get("architectures", ["Unknown"])[0]
if arch in ["LlamaForCausalLM"]:
# The default value for llama
return (5, 4, 8)
......
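
The refactor above splits the old auto_choose_speculative_params(self) into two steps: get_model_arch reads the architecture via get_config (respecting trust_remote_code, revision, and overrides, instead of requiring a local config.json), and auto_choose_speculative_params now takes the architecture string directly. A rough sketch of the new call pattern, showing only the Llama default visible in this diff; the error branch is an assumption, since the real function contains entries for other architectures not shown here.

```python
from typing import Tuple


def auto_choose_speculative_params(arch: str) -> Tuple[int, int, int]:
    # Returns (speculative_num_steps, speculative_eagle_topk,
    # speculative_num_draft_tokens) for a given architecture string.
    if arch in ["LlamaForCausalLM"]:
        # Default for Llama, as in the hunk above.
        return (5, 4, 8)
    # Sketch only: the real function handles more architectures.
    raise ValueError(f"No default speculative params for {arch}")


steps, topk, draft_tokens = auto_choose_speculative_params("LlamaForCausalLM")
assert (steps, topk, draft_tokens) == (5, 4, 8)
```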