Commit 4d51588e (unverified), authored by Yifan Qiao, committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
......@@ -45,6 +45,7 @@ from vllm.v1.core.kv_cache_utils import (
get_kv_cache_configs,
get_request_block_hasher,
init_none_hash,
resolve_kv_cache_block_sizes,
)
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -137,10 +138,8 @@ class EngineCore:
logger.warning("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.enable_chunked_prefill = False
scheduler_block_size = (
vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size
* vllm_config.parallel_config.prefill_context_parallel_size
scheduler_block_size, hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)
self.scheduler: SchedulerInterface = Scheduler(
......@@ -150,6 +149,7 @@ class EngineCore:
include_finished_set=include_finished_set,
log_stats=self.log_stats,
block_size=scheduler_block_size,
hash_block_size=hash_block_size,
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
......@@ -207,7 +207,7 @@ class EngineCore:
init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher(
scheduler_block_size, caching_hash_fn
hash_block_size, caching_hash_fn
)
self.step_fn = (
......
......@@ -4,6 +4,7 @@
from __future__ import annotations
import copy
from collections import Counter
from dataclasses import dataclass, fields, replace
from enum import IntEnum
from math import prod
......@@ -13,11 +14,11 @@ import torch
from typing_extensions import Self
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
logger = init_logger(__name__)
......@@ -95,6 +96,10 @@ class KVCacheSpec:
"""
raise NotImplementedError
@property
def storage_block_size(self) -> int:
    """Block size used when shaping the physical KV cache storage.

    The base implementation returns ``block_size`` unchanged. Specs that
    define a ``compress_ratio`` (``MLAAttentionSpec`` and
    ``SlidingWindowMLASpec`` in this file) override this property to
    return ``block_size // compress_ratio`` instead.
    """
    return self.block_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
"""
The maximum possible memory usage of this KV cache in bytes.
......@@ -269,6 +274,15 @@ class FullAttentionSpec(AttentionSpec):
)
def _apply_alignment_padding(spec: MLAAttentionSpec | SlidingWindowMLASpec):
if spec.alignment is None:
return
actual_page_size = spec.real_page_size_bytes
padded_page_size = round_up(actual_page_size, spec.alignment)
if padded_page_size != actual_page_size:
object.__setattr__(spec, "page_size_padded", padded_page_size)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
"""FullAttentionSpec with TQ-aware page size.
......@@ -299,15 +313,31 @@ class TQFullAttentionSpec(FullAttentionSpec):
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this
cache_dtype_str: str | None = None
# DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults.
alignment: int | None = None # Default to None for no padding.
compress_ratio: int = 1 # Default to 1 for no compression.
model_version: str | None = None
def __post_init__(self):
    """Run the parent post-init hook, then apply alignment padding.

    The padding step is delegated to ``_apply_alignment_padding``, which
    returns early when ``alignment`` is ``None`` (the field default).
    """
    # Parent hook first, so inherited fields are finalized before the
    # page size is computed for padding.
    super().__post_init__()
    _apply_alignment_padding(self)
@property
def storage_block_size(self) -> int:
    """Block size used for physical storage after applying compression.

    Returns ``block_size`` floor-divided by ``compress_ratio``. With the
    default ``compress_ratio`` of 1 this equals ``block_size``.
    """
    logical_block_size, ratio = self.block_size, self.compress_ratio
    return logical_block_size // ratio
@property
def real_page_size_bytes(self) -> int:
if self.cache_dtype_str == "fp8_ds_mla":
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
# for details.
if self.model_version == "deepseek_v4":
# DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token.
# head_size stays semantic (512); bytes are determined here.
return self.storage_block_size * 584
# V3.2 main MLA: 656-byte custom layout (kv_lora_rank=512 +
# qk_rope_head_dim=64, head_size=576). See flashmla_sparse.py.
return self.block_size * 656
return (
self.block_size
self.storage_block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
......@@ -319,9 +349,15 @@ class MLAAttentionSpec(FullAttentionSpec):
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
compress_ratio_set = set(spec.compress_ratio for spec in specs)
model_version_set = set(spec.model_version for spec in specs)
assert (
len(cache_dtype_str_set) == 1
and len(compress_ratio_set) == 1
and len(model_version_set) == 1
), (
"All attention layers in the same KV cache group must use the same "
"quantization method."
"quantization method, compress ratio, and model version."
)
return cls(
block_size=specs[0].block_size,
......@@ -331,6 +367,8 @@ class MLAAttentionSpec(FullAttentionSpec):
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(),
compress_ratio=compress_ratio_set.pop(),
model_version=model_version_set.pop(),
)
......@@ -393,6 +431,71 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass(frozen=True, kw_only=True)
class SlidingWindowMLASpec(SlidingWindowSpec):
    """Sliding window attention with MLA cache format.

    Adds the MLA-specific fields below on top of ``SlidingWindowSpec`` and
    derives the page size from ``storage_block_size`` rather than
    ``block_size``.
    """

    cache_dtype_str: str | None = None
    # DeepseekV4-only: see MLAAttentionSpec.model_version.
    alignment: int | None = None  # Default to None for no padding.
    compress_ratio: int = 1
    model_version: str | None = None

    def __post_init__(self):
        """Apply alignment padding; a no-op when ``alignment`` is ``None``."""
        _apply_alignment_padding(self)

    @property
    def storage_block_size(self) -> int:
        """Return ``block_size`` floor-divided by ``compress_ratio``."""
        logical_block_size, ratio = self.block_size, self.compress_ratio
        return logical_block_size // ratio

    @property
    def real_page_size_bytes(self) -> int:
        """Return the unpadded page size in bytes.

        For ``model_version == "deepseek_v4"`` each stored token takes a
        fixed 584 bytes. Any other non-``None`` version fails the assertion.
        Otherwise the size is the product of storage block size, KV head
        count, head size and element size of ``dtype``.
        """
        version = self.model_version
        if version == "deepseek_v4":
            # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token.
            return self.storage_block_size * 584

        assert version is None, f"Unsupported model version: {self.model_version}"

        bytes_per_token = (
            self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
        )
        return self.storage_block_size * bytes_per_token

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge the specs of one KV cache group into a single spec.

        All specs must be ``SlidingWindowMLASpec`` instances that agree on
        ``cache_dtype_str``, ``compress_ratio``, ``model_version`` and
        ``sliding_window``; otherwise an assertion fails. The remaining
        fields are copied from the first spec. ``alignment`` is not
        forwarded, so the merged spec keeps the first spec's
        ``page_size_padded`` as-is.
        """
        assert all(isinstance(spec, SlidingWindowMLASpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "SlidingWindowMLASpec."
        )

        # Collect the distinct values of every field that must be uniform.
        dtype_strs = {spec.cache_dtype_str for spec in specs}
        ratios = {spec.compress_ratio for spec in specs}
        versions = {spec.model_version for spec in specs}
        windows = {spec.sliding_window for spec in specs}

        is_uniform = all(
            len(values) == 1 for values in (dtype_strs, ratios, versions, windows)
        )
        assert is_uniform, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method, compress ratio, model version and sliding "
            "window size."
        )

        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            page_size_padded=specs[0].page_size_padded,
            sliding_window=windows.pop(),
            cache_dtype_str=dtype_strs.pop(),
            compress_ratio=ratios.pop(),
            model_version=versions.pop(),
        )
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
......@@ -527,7 +630,17 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
# Different block sizes, not uniform.
return False
one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, FullAttentionSpec):
# NOTE: Check subclasses before parent classes since isinstance()
# returns True for subclasses.
if isinstance(one_spec, SlidingWindowMLASpec):
# SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec
# with the same sliding_window size.
return all(
isinstance(spec, SlidingWindowMLASpec)
and spec.sliding_window == one_spec.sliding_window
for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, FullAttentionSpec):
return all(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
)
......@@ -571,6 +684,21 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
else:
return None
# NOTE: below util functions are only used by DeepseekV4 for now.
def get_page_sizes(self) -> list[int]:
    """Return the distinct ``page_size_bytes`` values of the held specs.

    Duplicates are removed through a set, so the order of the returned
    list is not tied to the order of ``kv_cache_specs``.
    """
    distinct_sizes = {spec.page_size_bytes for spec in self.kv_cache_specs.values()}
    return list(distinct_sizes)
def get_num_layer_tuples(self) -> int:
    """Return how many held specs share the most common ``page_size_bytes``.

    Raises:
        IndexError: If ``kv_cache_specs`` is empty.
    """
    size_tally = Counter()
    for spec in self.kv_cache_specs.values():
        size_tally[spec.page_size_bytes] += 1

    # most_common(1) yields a single (page_size, count) pair.
    _, top_count = size_tally.most_common(1)[0]
    return top_count
def max_memory_usage_pages(self, vllm_config: VllmConfig) -> int:
    """Return the largest per-spec page count across the held specs.

    For each spec, ``max_memory_usage_bytes(vllm_config)`` is divided by
    that spec's ``page_size_bytes`` using ``cdiv`` (presumably ceiling
    division -- confirm against ``vllm.utils.math_utils.cdiv``).

    Args:
        vllm_config: Forwarded to each spec's ``max_memory_usage_bytes``.

    Raises:
        ValueError: If ``kv_cache_specs`` is empty (``max`` of nothing).
    """
    pages_per_spec = [
        cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
        for spec in self.kv_cache_specs.values()
    ]
    return max(pages_per_spec)
@dataclass
class KVCacheTensor:
......@@ -593,6 +721,8 @@ class KVCacheGroupSpec:
layer_names: list[str]
# The KV cache spec of this manager layer
kv_cache_spec: KVCacheSpec
# Whether this group contains EAGLE/MTP draft attention layers.
is_eagle_group: bool = False
@dataclass
......
......@@ -84,6 +84,15 @@ class SpecDecodeBaseProposer:
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# DeepSeek V4 MTP consumes the target's pre-hc_head residual stream,
# shape (T, hc_mult * hidden_size). Expand the hidden_states buffer
# so target_hidden_states fits; detect DeepseekV4 via draft hf_config.
draft_hf_config = self.draft_model_config.hf_config
if hasattr(draft_hf_config, "compress_ratios") and hasattr(
draft_hf_config, "hc_mult"
):
self.hidden_size = self.hidden_size * draft_hf_config.hc_mult
# Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
......@@ -1308,9 +1317,12 @@ class SpecDecodeBaseProposer:
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
self._draft_attn_layer_names = (
set(all_attn_layers.keys()) - target_attn_layer_names
)
# Filter to only layers that have KV cache specs.
self._draft_attn_layer_names = {
name
for name in (set(all_attn_layers.keys()) - target_attn_layer_names)
if all_attn_layers[name].get_kv_cache_spec(self.vllm_config) is not None
}
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
......@@ -1514,6 +1526,17 @@ class SpecDecodeBaseProposer:
"Shared target model lm_head with MTP shared_head.head."
)
if hasattr(target_language_model.model, "topk_indices_buffer"):
if hasattr(self.model.model, "topk_indices_buffer"):
del self.model.model.topk_indices_buffer
self.model.model.topk_indices_buffer = (
target_language_model.model.topk_indices_buffer
)
logger.info(
"Detected MTP model with topk_indices_buffer. "
"Sharing target model topk_indices_buffer with the draft model."
)
if self.use_local_argmax_reduction:
if not hasattr(self.model, "get_top_tokens"):
raise ValueError(
......
......@@ -9,6 +9,7 @@ import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -162,7 +163,7 @@ def _reshape_kv_cache(
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.storage_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype,
......@@ -183,8 +184,28 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
if kv_cache_spec.page_size_padded is not None:
# Use strided view to handle page_size_bytes that
# include padding. This follows the same pattern as
# MambaSpec handling in gpu_model_runner.py.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for MLA backends but NOT for
# standard attention backends whose shape starts with
# a K/V dimension of size 2.
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
raw_tensor,
size=kv_cache_shape,
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
return kv_caches
......@@ -230,6 +251,7 @@ def build_attn_metadata(
seq_lens_cpu_upper_bound: torch.Tensor | None = None,
dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
positions: torch.Tensor | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
......@@ -256,6 +278,7 @@ def build_attn_metadata(
slot_mapping=slot_mapping,
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
positions=positions,
)
if encoder_seq_lens and i in encoder_seq_lens:
encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i]
......
......@@ -486,11 +486,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
),
)
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
last_hidden_states=spec_hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
......@@ -808,7 +817,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
out=seq_lens_cpu_upper_bound_np[:num_reqs],
)
seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
......@@ -1233,11 +1241,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
assert self.sampler is not None
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
spec_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
......
......@@ -193,5 +193,6 @@ class DefaultModelState(ModelState):
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
positions=input_batch.positions,
)
return attn_metadata
......@@ -53,6 +53,11 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
# Widen for HC-multiplexed residuals (e.g. DeepSeek V4 feeds the MTP
# draft the target's pre-hc_head (T, hc_mult * hidden_size) residual).
# Non-HC models default to hc_mult=1 and are unaffected.
hc_mult = getattr(self.draft_model_config.hf_config, "hc_mult", 1)
self.hidden_size = self.hidden_size * hc_mult
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
......
......@@ -49,4 +49,19 @@ def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Mod
del eagle_model.lm_head
eagle_model.lm_head = target_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces zero/NaN logits. Share it explicitly from the target.
inner = getattr(eagle_model, "model", None)
layers = getattr(inner, "layers", None) if inner is not None else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_model.lm_head
return eagle_model
......@@ -104,6 +104,7 @@ class RequestState:
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens_np[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
self.num_computed_tokens_np[req_idx] = num_computed_tokens
if num_computed_tokens > 0 and num_computed_tokens <= prefill_len:
# For PD disagg or resumed requests: set last_sampled to the last
......
......@@ -2184,6 +2184,7 @@ class GPUModelRunner(
slot_mapping=slot_mapping_gid_0,
causal=True,
is_prefilling=is_prefilling,
positions=self.positions[:num_tokens_padded],
)
if self.dcp_world_size > 1:
......@@ -4671,6 +4672,16 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). Safe to
# rebind here: hidden_states was already consumed for sampling
# above and is not used again in this branch.
alt = getattr(
self.get_model(), "get_mtp_target_hidden_states", lambda: None
)()
if alt is not None:
hidden_states = alt
num_rejected_tokens_gpu = None
if spec_decode_metadata is None:
token_indices_to_sample = None
......@@ -6587,9 +6598,15 @@ class GPUModelRunner(
)
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
# For MLA with compression, storage_block_size != block_size
if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
shape_block_size = kv_cache_spec.storage_block_size
else:
shape_block_size = kernel_block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
shape_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
......@@ -6613,12 +6630,31 @@ class GPUModelRunner(
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = (
kv_cache_raw_tensors[layer_name]
.view(dtype)
.view(kv_cache_shape)
.permute(*inv_order)
)
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
if kv_cache_spec.page_size_padded is not None:
# Use strided view to handle page_size_bytes that
# include padding. This follows
# the same pattern as MambaSpec handling below.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for MLA backends but NOT for
# standard attention backends whose shape starts with
# a K/V dimension of size 2.
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
raw_tensor,
size=kv_cache_shape,
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment