Unverified commit fb16fbaf authored by Lifu Huang, committed by GitHub

Fix incorrect KV cache allocation for MTP models. (#8482)


Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
parent 0ce84c82
@@ -261,6 +261,9 @@ class ModelConfig:
         self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_nextn_predict_layers = getattr(
+            self.hf_text_config, "num_nextn_predict_layers", None
+        )
         self.vocab_size = self.hf_text_config.vocab_size
         # Verify quantization
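The new `ModelConfig` field falls back to `None` when the checkpoint defines no MTP head. A minimal sketch of that fallback, using `SimpleNamespace` as a stand-in for the Hugging Face text config (the layer counts below are illustrative, not taken from this PR):

```python
# Sketch of the getattr fallback above; SimpleNamespace stands in for the
# Hugging Face text config, which only defines num_nextn_predict_layers on
# MTP-capable checkpoints such as DeepSeek-V3 or GLM-4.5.
from types import SimpleNamespace

deepseek_like = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)
llama_like = SimpleNamespace(num_hidden_layers=32)  # no MTP layers

for cfg in (deepseek_like, llama_like):
    # A missing attribute resolves to None, which the runner later treats
    # as "this model has no MTP layers".
    print(getattr(cfg, "num_nextn_predict_layers", None))
# -> 1, then None
```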
@@ -285,11 +285,21 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        self.start_layer = getattr(self.model, "start_layer", 0)
-        self.end_layer = getattr(
-            self.model, "end_layer", self.model_config.num_hidden_layers
-        )
+        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used
+        # separately as draft models for speculative decoding. In those cases,
+        # `num_nextn_predict_layers` is used to determine the number of layers.
+        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
+        model_num_layers = (
+            self.model_config.num_nextn_predict_layers
+            if self.is_draft_worker and model_has_mtp_layers
+            else self.model_config.num_hidden_layers
+        )
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
+        assert (not model_has_mtp_layers) or (
+            self.num_effective_layers == model_num_layers
+        ), "PP is not compatible with MTP models."

         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
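The draft/target layer selection introduced above can be read as a small pure function. The following is an illustrative restatement, not code from the repository; the names mirror the diff, and `effective_layer_count` itself is a hypothetical helper:

```python
# Hypothetical standalone restatement of the layer-count logic in the hunk
# above, making the draft/target split explicit.
from typing import Optional

def effective_layer_count(
    num_hidden_layers: int,
    num_nextn_predict_layers: Optional[int],
    is_draft_worker: bool,
) -> int:
    """Number of layers this worker must allocate KV cache for."""
    has_mtp_layers = num_nextn_predict_layers is not None
    if is_draft_worker and has_mtp_layers:
        # The draft worker runs only the MTP head(s), e.g. 1 for DeepSeek-V3.
        return num_nextn_predict_layers
    # The target worker (or any non-MTP model) runs the full transformer stack.
    return num_hidden_layers

assert effective_layer_count(61, 1, is_draft_worker=True) == 1
assert effective_layer_count(61, 1, is_draft_worker=False) == 61
assert effective_layer_count(32, None, is_draft_worker=True) == 32
```

The trailing assert in the diff guards the pipeline-parallel case: if the MTP layer count differs from the worker's effective layer range, the model has been split across pipeline stages, which this scheme does not support.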
@@ -1178,11 +1188,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),  # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
@@ -1195,11 +1201,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),  # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
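Both MLA pool call sites now take their layer count from `self.num_effective_layers` instead of re-deriving it inline. The layer count is what drives KV-cache memory, so a wrong value directly mis-sizes the pool. A rough sizing sketch, assuming the MLA pool stores `kv_lora_rank + qk_rope_head_dim` values per token per layer (DeepSeek-V3-like numbers, purely illustrative):

```python
# Back-of-envelope MLA KV-cache sizing; the per-token layout and the default
# numbers below are assumptions for illustration, not values from this PR.
def mla_kv_bytes(num_tokens: int, layer_num: int,
                 kv_lora_rank: int = 512, qk_rope_head_dim: int = 64,
                 dtype_bytes: int = 2) -> int:
    return num_tokens * layer_num * (kv_lora_rank + qk_rope_head_dim) * dtype_bytes

full = mla_kv_bytes(num_tokens=100_000, layer_num=61)  # full target model
draft = mla_kv_bytes(num_tokens=100_000, layer_num=1)  # single MTP draft layer
print(full / 2**30, draft / 2**30)  # ~6.5 GiB vs ~0.1 GiB
```

Under these illustrative numbers, sizing the draft pool with the full layer count would overallocate by roughly 60x, which is the kind of incorrect allocation the commit title refers to.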