Unverified Commit f381cf23 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix broken MTP weight loading for FP8 KV Scales (#27227)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 5ff5d94e
...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -278,6 +281,10 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -278,6 +281,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# According to DeepSeek-V3 Technical Report, MTP modules # According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights. # shares embedding layer. We only load the first weights.
if ( if (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment