Commit a7378418 authored by zhuwenwen
Browse files

[Model][DSV4] Support base model (#41006)

parent fab1acce
......@@ -10,7 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_rank,
......@@ -65,6 +65,8 @@ from .utils import (
maybe_prefix,
)
# Values of the hf_config ``expert_dtype`` field accepted for DeepSeek V4.
# ``DeepseekV4FP8Config.expert_dtype`` raises ValueError for anything else.
_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")
class DeepseekV4MLP(nn.Module):
def __init__(
......@@ -118,16 +120,59 @@ class DeepseekV4MLP(nn.Module):
class DeepseekV4FP8Config(Fp8Config):
"""FP8 config that routes MoE layers to MXFP4 quantization.
DeepSeek V4 checkpoints use FP8 for linear/attention layers but
MXFP4 for MoE expert weights. This config inherits standard FP8
behavior and overrides only the MoE dispatch.
"""FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch.
DeepSeek V4 checkpoints always use FP8 block quantization for
linear/attention layers. The MoE expert weights vary by checkpoint:
- ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts
with ue8m0 (e8m0fnu) FP8 linear scales.
- ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
experts with float32 FP8 linear scales.
The dispatch and the linear scale dtype are both keyed off
``expert_dtype`` from the model's hf_config; missing values default
to ``"fp4"`` so existing FP4 checkpoints stay unchanged.
NOTE: ``expert_dtype`` is resolved lazily because this config is
constructed during VllmConfig setup, before ``set_current_vllm_config``
is active. Reading hf_config eagerly in ``__init__`` would always see
the default ``"fp4"`` and silently misroute Flash-Base checkpoints.
"""
def __init__(self, *args, **kwargs):
    """Build the config with a not-yet-resolved expert dtype.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # NOTE: ``is_scale_e8m0`` must not be assigned here. It is a read-only
    # property on this class (no setter), so ``self.is_scale_e8m0 = ...``
    # would raise AttributeError. The property resolves on first read,
    # by which time the current vllm_config has been set.
    #
    # Cache for the lazily resolved ``expert_dtype``; ``None`` means the
    # value has not been read from the model's hf_config yet.
    self._resolved_expert_dtype: str | None = None
@property
def expert_dtype(self) -> str:
    """Return the checkpoint's MoE expert dtype, resolving it lazily.

    The value is read from ``expert_dtype`` on the current vllm_config's
    hf_config (defaulting to ``"fp4"`` when the attribute is absent) and
    cached on the instance after the first successful read.

    Returns:
        The cached or freshly resolved expert dtype. When the current
        vllm_config cannot be read, ``"fp4"`` is returned *without*
        caching so a later call can still resolve the real value.

    Raises:
        ValueError: if hf_config carries an ``expert_dtype`` that is not
            listed in ``_DEEPSEEK_V4_EXPERT_DTYPES``.
    """
    already_resolved = self._resolved_expert_dtype
    if already_resolved is not None:
        return already_resolved

    try:
        model_hf_config = get_current_vllm_config().model_config.hf_config
    except Exception:
        # vllm_config not yet set; defer the decision until a
        # later call lands inside set_current_vllm_config.
        return "fp4"

    expert_dtype = getattr(model_hf_config, "expert_dtype", "fp4")
    if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
        raise ValueError(
            f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; "
            f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
        )

    # Cache first, then announce the decision (logged once).
    self._resolved_expert_dtype = expert_dtype

    from vllm.logger import init_logger

    init_logger(__name__).info_once(
        "DeepSeek V4 expert_dtype resolved to %r", expert_dtype
    )
    return expert_dtype
@property
def is_scale_e8m0(self) -> bool:
    """Tell whether FP8 linear scales are stored as e8m0fnu.

    Derived from ``expert_dtype`` on every read rather than stored, so it
    follows whatever that property currently resolves to.
    """
    # FP4-expert checkpoints keep FP8 linear scales as e8m0fnu, whereas
    # FP8-expert checkpoints (Flash-Base) keep them as float32.
    uses_e8m0_scales = self.expert_dtype == "fp4"
    return uses_e8m0_scales
@classmethod
def get_name(cls) -> QuantizationMethods:
......@@ -155,11 +200,14 @@ class DeepseekV4FP8Config(Fp8Config):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
if self.expert_dtype == "fp4":
return Mxfp4MoEMethod(layer.moe_config)
# expert_dtype == "fp8": fall through to Fp8Config which
# returns Fp8MoEMethod with block-wise float32 scales.
return super().get_quant_method(layer, prefix)
def is_mxfp4_quant(self, prefix, layer):
    """Return True when ``layer`` is a FusedMoE layer with fp4 experts.

    ``prefix`` is not consulted; only the layer type and the resolved
    ``expert_dtype`` decide the result.
    """
    # The layer type alone is not sufficient: checkpoints with
    # expert_dtype == "fp8" also use FusedMoE layers but are not MXFP4.
    # A bare ``return isinstance(layer, FusedMoE)`` ahead of this line
    # would make the dtype check unreachable.
    return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
@triton.jit
......@@ -689,6 +737,12 @@ class DeepseekV4MoE(nn.Module):
raise NotImplementedError(
"DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
)
if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4":
raise NotImplementedError(
"DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
f"{config.expert_dtype!r}. Drop --kernel-config moe_backend="
"deep_gemm_mega_moe for this checkpoint."
)
self.gate = GateLinear(
config.hidden_size,
......@@ -1410,10 +1464,24 @@ def hc_head(
return y.to(dtype)
class DeepseekV4ForCausalLM(nn.Module):
model_cls = DeepseekV4Model
hf_to_vllm_mapper = WeightsMapper(
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
if expert_dtype == "fp4":
# MXFP4 experts use Mxfp4MoEMethod, which registers scales as
# ``w{1,2,3}_weight_scale`` (no _inv suffix). FP8 linear and
# shared experts use Fp8LinearMethod's block scales, which
# register as ``weight_scale_inv``.
scale_regex = {
re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale",
re.compile(r"\.scale$"): ".weight_scale_inv",
}
else:
# FP8 experts use Fp8MoEMethod (block_quant=True), which registers
# scales as ``w{13,2}_weight_scale_inv``. Map all ``.scale`` keys
# there.
scale_regex = {
re.compile(r"\.scale$"): ".weight_scale_inv",
}
return WeightsMapper(
orig_to_new_prefix={
"layers.": "model.layers.",
"embed.": "model.embed.",
......@@ -1421,12 +1489,7 @@ class DeepseekV4ForCausalLM(nn.Module):
"hc_head": "model.hc_head",
"mtp.": "model.mtp.",
},
orig_to_new_regex={
# Routed MoE expert scales: experts.N.wX.scale -> .weight_scale
re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale",
# Everything else (FP8 linear + shared experts): .scale -> .weight_scale_inv
re.compile(r"\.scale$"): ".weight_scale_inv",
},
orig_to_new_regex=scale_regex,
orig_to_new_suffix={
"head.weight": "lm_head.weight",
"embed.weight": "embed_tokens.weight",
......@@ -1438,11 +1501,22 @@ class DeepseekV4ForCausalLM(nn.Module):
},
)
class DeepseekV4ForCausalLM(nn.Module):
model_cls = DeepseekV4Model
# Default mapper assumes the original FP4-expert checkpoint layout.
# Overridden per-instance in __init__ when expert_dtype != "fp4".
hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
expert_dtype = getattr(config, "expert_dtype", "fp4")
if expert_dtype != "fp4":
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
......
......@@ -48,9 +48,14 @@ from .utils import maybe_prefix
logger = init_logger(__name__)
# MoE expert scales are fused into per-layer w13/w2 tensors. The exact
# parameter suffix depends on which FusedMoE method handles the experts:
# - fp4 experts (Mxfp4MoEMethod) register ``..._weight_scale``;
# - fp8 experts (Fp8MoEMethod with block_quant=True) register
#   ``..._weight_scale_inv``.
# Other FP8 linear scales (including shared experts) always use
# ``.weight_scale_inv``. Mirrors the per-instance mapper built by
# ``_make_deepseek_v4_weights_mapper`` in deepseek_v4.py.
#
# Matches routed-expert scale keys, i.e. names ending in
# ``.experts.<N>.w1.scale``, ``.w2.scale`` or ``.w3.scale``.
_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$")
......@@ -326,6 +331,15 @@ class DeepSeekV4MTP(nn.Module):
num_experts=self.config.n_routed_experts,
)
# FP8 experts register ``..._weight_scale_inv`` (block_quant) while
# FP4/MXFP4 experts register ``..._weight_scale``. Choose the suffix
# for the rename below based on the model's expert dtype.
expert_scale_suffix = (
".weight_scale"
if getattr(self.config, "expert_dtype", "fp4") == "fp4"
else ".weight_scale_inv"
)
for name, loaded_weight in weights:
mtp_layer_idx = _find_mtp_layer_idx(name)
# V4 checkpoints store MTP weights as `mtp.{i}.*`; remap to
......@@ -347,7 +361,7 @@ class DeepSeekV4MTP(nn.Module):
continue
if name.endswith(".scale"):
suffix = (
".weight_scale"
expert_scale_suffix
if _EXPERT_SCALE_RE.search(name)
else ".weight_scale_inv"
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment