"lib/llm/src/vscode:/vscode.git/clone" did not exist on "ffccc72268838c1d5acdd73fbde8570358f30c90"
Unverified commit 4d51588e, authored by Yifan Qiao, committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
...@@ -45,6 +45,7 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -45,6 +45,7 @@ from vllm.v1.core.kv_cache_utils import (
get_kv_cache_configs, get_kv_cache_configs,
get_request_block_hasher, get_request_block_hasher,
init_none_hash, init_none_hash,
resolve_kv_cache_block_sizes,
) )
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -137,10 +138,8 @@ class EngineCore: ...@@ -137,10 +138,8 @@ class EngineCore:
logger.warning("Disabling chunked prefill for model without KVCache") logger.warning("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
scheduler_block_size = ( scheduler_block_size, hash_block_size = resolve_kv_cache_block_sizes(
vllm_config.cache_config.block_size kv_cache_config, vllm_config
* vllm_config.parallel_config.decode_context_parallel_size
* vllm_config.parallel_config.prefill_context_parallel_size
) )
self.scheduler: SchedulerInterface = Scheduler( self.scheduler: SchedulerInterface = Scheduler(
...@@ -150,6 +149,7 @@ class EngineCore: ...@@ -150,6 +149,7 @@ class EngineCore:
include_finished_set=include_finished_set, include_finished_set=include_finished_set,
log_stats=self.log_stats, log_stats=self.log_stats,
block_size=scheduler_block_size, block_size=scheduler_block_size,
hash_block_size=hash_block_size,
) )
self.use_spec_decode = vllm_config.speculative_config is not None self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore if self.scheduler.connector is not None: # type: ignore
...@@ -207,7 +207,7 @@ class EngineCore: ...@@ -207,7 +207,7 @@ class EngineCore:
init_none_hash(caching_hash_fn) init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher( self.request_block_hasher = get_request_block_hasher(
scheduler_block_size, caching_hash_fn hash_block_size, caching_hash_fn
) )
self.step_fn = ( self.step_fn = (
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
from collections import Counter
from dataclasses import dataclass, fields, replace from dataclasses import dataclass, fields, replace
from enum import IntEnum from enum import IntEnum
from math import prod from math import prod
...@@ -13,11 +14,11 @@ import torch ...@@ -13,11 +14,11 @@ import torch
from typing_extensions import Self from typing_extensions import Self
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -95,6 +96,10 @@ class KVCacheSpec: ...@@ -95,6 +96,10 @@ class KVCacheSpec:
""" """
raise NotImplementedError raise NotImplementedError
# Block size handed to attn_backend.get_kv_cache_shape when the cache tensor
# is shaped. This base implementation returns block_size unchanged;
# MLAAttentionSpec and SlidingWindowMLASpec override it to floor-divide
# block_size by compress_ratio.
@property
def storage_block_size(self) -> int:
return self.block_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
""" """
The maximum possible memory usage of this KV cache in bytes. The maximum possible memory usage of this KV cache in bytes.
...@@ -269,6 +274,15 @@ class FullAttentionSpec(AttentionSpec): ...@@ -269,6 +274,15 @@ class FullAttentionSpec(AttentionSpec):
) )
def _apply_alignment_padding(spec: MLAAttentionSpec | SlidingWindowMLASpec):
if spec.alignment is None:
return
actual_page_size = spec.real_page_size_bytes
padded_page_size = round_up(actual_page_size, spec.alignment)
if padded_page_size != actual_page_size:
object.__setattr__(spec, "page_size_padded", padded_page_size)
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec): class TQFullAttentionSpec(FullAttentionSpec):
"""FullAttentionSpec with TQ-aware page size. """FullAttentionSpec with TQ-aware page size.
...@@ -299,15 +313,31 @@ class TQFullAttentionSpec(FullAttentionSpec): ...@@ -299,15 +313,31 @@ class TQFullAttentionSpec(FullAttentionSpec):
class MLAAttentionSpec(FullAttentionSpec): class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this # TODO(Lucas/Chen): less hacky way to do this
cache_dtype_str: str | None = None cache_dtype_str: str | None = None
# DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults.
alignment: int | None = None # Default to None for no padding.
compress_ratio: int = 1 # Default to 1 for no compression.
model_version: str | None = None
# Run the parent's post-init first, then let _apply_alignment_padding record
# a padded page size if an alignment is set on this spec.
def __post_init__(self):
super().__post_init__()
_apply_alignment_padding(self)
# block_size floor-divided by compress_ratio. compress_ratio defaults to 1,
# in which case this equals block_size.
@property
def storage_block_size(self) -> int:
return self.block_size // self.compress_ratio
@property @property
def real_page_size_bytes(self) -> int: def real_page_size_bytes(self) -> int:
if self.cache_dtype_str == "fp8_ds_mla": if self.cache_dtype_str == "fp8_ds_mla":
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py` if self.model_version == "deepseek_v4":
# for details. # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token.
# head_size stays semantic (512); bytes are determined here.
return self.storage_block_size * 584
# V3.2 main MLA: 656-byte custom layout (kv_lora_rank=512 +
# qk_rope_head_dim=64, head_size=576). See flashmla_sparse.py.
return self.block_size * 656 return self.block_size * 656
return ( return (
self.block_size self.storage_block_size
* self.num_kv_heads * self.num_kv_heads
* self.head_size * self.head_size
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
...@@ -319,9 +349,15 @@ class MLAAttentionSpec(FullAttentionSpec): ...@@ -319,9 +349,15 @@ class MLAAttentionSpec(FullAttentionSpec):
"All attention layers in the same KV cache group must be MLAAttentionSpec." "All attention layers in the same KV cache group must be MLAAttentionSpec."
) )
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, ( compress_ratio_set = set(spec.compress_ratio for spec in specs)
model_version_set = set(spec.model_version for spec in specs)
assert (
len(cache_dtype_str_set) == 1
and len(compress_ratio_set) == 1
and len(model_version_set) == 1
), (
"All attention layers in the same KV cache group must use the same " "All attention layers in the same KV cache group must use the same "
"quantization method." "quantization method, compress ratio, and model version."
) )
return cls( return cls(
block_size=specs[0].block_size, block_size=specs[0].block_size,
...@@ -331,6 +367,8 @@ class MLAAttentionSpec(FullAttentionSpec): ...@@ -331,6 +367,8 @@ class MLAAttentionSpec(FullAttentionSpec):
kv_quant_mode=specs[0].kv_quant_mode, kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded, page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(), cache_dtype_str=cache_dtype_str_set.pop(),
compress_ratio=compress_ratio_set.pop(),
model_version=model_version_set.pop(),
) )
...@@ -393,6 +431,71 @@ class SlidingWindowSpec(AttentionSpec): ...@@ -393,6 +431,71 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass(frozen=True, kw_only=True)
class SlidingWindowMLASpec(SlidingWindowSpec):
    """KV cache spec for sliding-window layers that use the MLA cache format.

    Extends ``SlidingWindowSpec`` with the cache dtype string and the
    DeepseekV4-only settings ``alignment``, ``compress_ratio`` and
    ``model_version``.
    """

    cache_dtype_str: str | None = None
    # The remaining fields are only set for DeepseekV4.
    # Page-size alignment in bytes; None means no padding is applied.
    alignment: int | None = None
    # Divisor applied to block_size to obtain storage_block_size.
    compress_ratio: int = 1
    model_version: str | None = None

    def __post_init__(self):
        # Sets page_size_padded when an alignment is configured.
        _apply_alignment_padding(self)

    @property
    def storage_block_size(self) -> int:
        """block_size floor-divided by compress_ratio."""
        return self.block_size // self.compress_ratio

    @property
    def real_page_size_bytes(self) -> int:
        """Unpadded size in bytes of one page for this spec."""
        if self.model_version == "deepseek_v4":
            # DeepseekV4 layout per token:
            # 448B NoPE + 128B RoPE + 8B fp8 scale = 584B.
            return self.storage_block_size * 584
        assert self.model_version is None, (
            f"Unsupported model version: {self.model_version}"
        )
        slots_per_page = self.storage_block_size
        bytes_per_slot = (
            self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
        )
        return slots_per_page * bytes_per_slot

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Combine the specs of one KV cache group into a single spec.

        Every entry must be a ``SlidingWindowMLASpec`` and all entries must
        agree on cache dtype string, compress ratio, model version and sliding
        window; otherwise an ``AssertionError`` is raised. The remaining
        fields are copied from the first spec.
        """
        assert all(isinstance(spec, SlidingWindowMLASpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "SlidingWindowMLASpec."
        )
        cache_dtypes = {spec.cache_dtype_str for spec in specs}
        compress_ratios = {spec.compress_ratio for spec in specs}
        model_versions = {spec.model_version for spec in specs}
        sliding_windows = {spec.sliding_window for spec in specs}
        # Each property must take exactly one distinct value across the group.
        is_uniform = all(
            len(distinct) == 1
            for distinct in (
                cache_dtypes,
                compress_ratios,
                model_versions,
                sliding_windows,
            )
        )
        assert is_uniform, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method, compress ratio, model version and sliding "
            "window size."
        )
        first = specs[0]
        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            page_size_padded=first.page_size_padded,
            sliding_window=sliding_windows.pop(),
            cache_dtype_str=cache_dtypes.pop(),
            compress_ratio=compress_ratios.pop(),
            model_version=model_versions.pop(),
        )
@dataclass(frozen=True) @dataclass(frozen=True)
class MambaSpec(KVCacheSpec): class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...] shapes: tuple[tuple[int, ...], ...]
...@@ -527,7 +630,17 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): ...@@ -527,7 +630,17 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
# Different block sizes, not uniform. # Different block sizes, not uniform.
return False return False
one_spec = next(iter(kv_cache_specs.values())) one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, FullAttentionSpec): # NOTE: Check subclasses before parent classes since isinstance()
# returns True for subclasses.
if isinstance(one_spec, SlidingWindowMLASpec):
# SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec
# with the same sliding_window size.
return all(
isinstance(spec, SlidingWindowMLASpec)
and spec.sliding_window == one_spec.sliding_window
for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, FullAttentionSpec):
return all( return all(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
) )
...@@ -571,6 +684,21 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): ...@@ -571,6 +684,21 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
else: else:
return None return None
# NOTE: below util functions are only used by DeepseekV4 for now.
def get_page_sizes(self) -> list[int]:
    """Return the distinct ``page_size_bytes`` values of the contained specs.

    Duplicates are removed through a set, so the order of the returned list
    does not follow layer order.
    """
    distinct_sizes = {
        spec.page_size_bytes for spec in self.kv_cache_specs.values()
    }
    return list(distinct_sizes)
def get_num_layer_tuples(self) -> int:
    """Return how many contained specs share the most common page size."""
    page_size_counts = Counter(
        spec.page_size_bytes for spec in self.kv_cache_specs.values()
    )
    # most_common(1) yields [(page_size, count)] for the top entry.
    top_entry = page_size_counts.most_common(1)[0]
    return top_entry[1]
def max_memory_usage_pages(self, vllm_config: VllmConfig) -> int:
    """Return the largest per-spec page count needed at maximum memory usage.

    For each contained spec, its maximum memory usage in bytes is divided by
    its page size with ``cdiv``; the maximum over all specs is returned.
    """
    pages_per_spec = [
        cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
        for spec in self.kv_cache_specs.values()
    ]
    return max(pages_per_spec)
@dataclass @dataclass
class KVCacheTensor: class KVCacheTensor:
...@@ -593,6 +721,8 @@ class KVCacheGroupSpec: ...@@ -593,6 +721,8 @@ class KVCacheGroupSpec:
layer_names: list[str] layer_names: list[str]
# The KV cache spec of this manager layer # The KV cache spec of this manager layer
kv_cache_spec: KVCacheSpec kv_cache_spec: KVCacheSpec
# Whether this group contains EAGLE/MTP draft attention layers.
is_eagle_group: bool = False
@dataclass @dataclass
......
...@@ -84,6 +84,15 @@ class SpecDecodeBaseProposer: ...@@ -84,6 +84,15 @@ class SpecDecodeBaseProposer:
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# DeepSeek V4 MTP consumes the target's pre-hc_head residual stream,
# shape (T, hc_mult * hidden_size). Expand the hidden_states buffer
# so target_hidden_states fits; detect DeepseekV4 via draft hf_config.
draft_hf_config = self.draft_model_config.hf_config
if hasattr(draft_hf_config, "compress_ratios") and hasattr(
draft_hf_config, "hc_mult"
):
self.hidden_size = self.hidden_size * draft_hf_config.hc_mult
# Unifying eagle, draft model, and parallel drafting support. # Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass), # DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE) # but has an additional slot for the next_token_id (does not shift like EAGLE)
...@@ -1308,9 +1317,12 @@ class SpecDecodeBaseProposer: ...@@ -1308,9 +1317,12 @@ class SpecDecodeBaseProposer:
self.vllm_config, self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract] AttentionLayerBase, # type: ignore[type-abstract]
) )
self._draft_attn_layer_names = ( # Filter to only layers that have KV cache specs.
set(all_attn_layers.keys()) - target_attn_layer_names self._draft_attn_layer_names = {
) name
for name in (set(all_attn_layers.keys()) - target_attn_layer_names)
if all_attn_layers[name].get_kv_cache_spec(self.vllm_config) is not None
}
if self.supports_mm_inputs: if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use # Even if the target model is multimodal, we can also use
...@@ -1514,6 +1526,17 @@ class SpecDecodeBaseProposer: ...@@ -1514,6 +1526,17 @@ class SpecDecodeBaseProposer:
"Shared target model lm_head with MTP shared_head.head." "Shared target model lm_head with MTP shared_head.head."
) )
if hasattr(target_language_model.model, "topk_indices_buffer"):
if hasattr(self.model.model, "topk_indices_buffer"):
del self.model.model.topk_indices_buffer
self.model.model.topk_indices_buffer = (
target_language_model.model.topk_indices_buffer
)
logger.info(
"Detected MTP model with topk_indices_buffer. "
"Sharing target model topk_indices_buffer with the draft model."
)
if self.use_local_argmax_reduction: if self.use_local_argmax_reduction:
if not hasattr(self.model, "get_top_tokens"): if not hasattr(self.model, "get_top_tokens"):
raise ValueError( raise ValueError(
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -162,7 +163,7 @@ def _reshape_kv_cache( ...@@ -162,7 +163,7 @@ def _reshape_kv_cache(
attn_backend = attn_backends[layer_name] attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, num_blocks,
kv_cache_spec.block_size, kv_cache_spec.storage_block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size, kv_cache_spec.head_size,
cache_dtype, cache_dtype,
...@@ -183,8 +184,28 @@ def _reshape_kv_cache( ...@@ -183,8 +184,28 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype) raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape) if kv_cache_spec.page_size_padded is not None:
kv_caches[layer_name] = raw_tensor.permute(*inv_order) # Use strided view to handle page_size_bytes that
# include padding. This follows the same pattern as
# MambaSpec handling in gpu_model_runner.py.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for MLA backends but NOT for
# standard attention backends whose shape starts with
# a K/V dimension of size 2.
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
raw_tensor,
size=kv_cache_shape,
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
return kv_caches return kv_caches
...@@ -230,6 +251,7 @@ def build_attn_metadata( ...@@ -230,6 +251,7 @@ def build_attn_metadata(
seq_lens_cpu_upper_bound: torch.Tensor | None = None, seq_lens_cpu_upper_bound: torch.Tensor | None = None,
dcp_local_seq_lens: torch.Tensor | None = None, dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None, encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
positions: torch.Tensor | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs] seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None: if dcp_local_seq_lens is not None:
...@@ -256,6 +278,7 @@ def build_attn_metadata( ...@@ -256,6 +278,7 @@ def build_attn_metadata(
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
causal=True, causal=True,
dcp_local_seq_lens=dcp_local_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens,
positions=positions,
) )
if encoder_seq_lens and i in encoder_seq_lens: if encoder_seq_lens and i in encoder_seq_lens:
encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i] encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i]
......
...@@ -486,11 +486,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -486,11 +486,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device, device=self.device,
), ),
) )
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
self.speculator.propose( self.speculator.propose(
input_batch=input_batch, input_batch=input_batch,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer, slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states, last_hidden_states=spec_hidden_states,
aux_hidden_states=aux_hidden_states, aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones( num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device input_batch.num_reqs, dtype=torch.int32, device=self.device
...@@ -808,7 +817,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -808,7 +817,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
out=seq_lens_cpu_upper_bound_np[:num_reqs], out=seq_lens_cpu_upper_bound_np[:num_reqs],
) )
seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np) seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
...@@ -1233,11 +1241,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1233,11 +1241,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None: if self.speculator is not None:
assert self.sampler is not None assert self.sampler is not None
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
attn_metadata, attn_metadata,
slot_mappings_by_layer, slot_mappings_by_layer,
hidden_states, spec_hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
......
...@@ -193,5 +193,6 @@ class DefaultModelState(ModelState): ...@@ -193,5 +193,6 @@ class DefaultModelState(ModelState):
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound, seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens, dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
positions=input_batch.positions,
) )
return attn_metadata return attn_metadata
...@@ -53,6 +53,11 @@ class EagleSpeculator: ...@@ -53,6 +53,11 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's # the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B). # hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
# Widen for HC-multiplexed residuals (e.g. DeepSeek V4 feeds the MTP
# draft the target's pre-hc_head (T, hc_mult * hidden_size) residual).
# Non-HC models default to hc_mult=1 and are unaffected.
hc_mult = getattr(self.draft_model_config.hf_config, "hc_mult", 1)
self.hidden_size = self.hidden_size * hc_mult
self.vocab_size = self.draft_model_config.get_vocab_size() self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
......
...@@ -49,4 +49,19 @@ def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Mod ...@@ -49,4 +49,19 @@ def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Mod
del eagle_model.lm_head del eagle_model.lm_head
eagle_model.lm_head = target_model.lm_head eagle_model.lm_head = target_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces zero/NaN logits. Share it explicitly from the target.
inner = getattr(eagle_model, "model", None)
layers = getattr(inner, "layers", None) if inner is not None else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_model.lm_head
return eagle_model return eagle_model
...@@ -104,6 +104,7 @@ class RequestState: ...@@ -104,6 +104,7 @@ class RequestState:
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens_np[req_idx] = num_computed_tokens self.num_computed_tokens_np[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens) self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
self.num_computed_tokens_np[req_idx] = num_computed_tokens
if num_computed_tokens > 0 and num_computed_tokens <= prefill_len: if num_computed_tokens > 0 and num_computed_tokens <= prefill_len:
# For PD disagg or resumed requests: set last_sampled to the last # For PD disagg or resumed requests: set last_sampled to the last
......
...@@ -2184,6 +2184,7 @@ class GPUModelRunner( ...@@ -2184,6 +2184,7 @@ class GPUModelRunner(
slot_mapping=slot_mapping_gid_0, slot_mapping=slot_mapping_gid_0,
causal=True, causal=True,
is_prefilling=is_prefilling, is_prefilling=is_prefilling,
positions=self.positions[:num_tokens_padded],
) )
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
...@@ -4671,6 +4672,16 @@ class GPUModelRunner( ...@@ -4671,6 +4672,16 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). Safe to
# rebind here: hidden_states was already consumed for sampling
# above and is not used again in this branch.
alt = getattr(
self.get_model(), "get_mtp_target_hidden_states", lambda: None
)()
if alt is not None:
hidden_states = alt
num_rejected_tokens_gpu = None num_rejected_tokens_gpu = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
token_indices_to_sample = None token_indices_to_sample = None
...@@ -6587,9 +6598,15 @@ class GPUModelRunner( ...@@ -6587,9 +6598,15 @@ class GPUModelRunner(
) )
kernel_num_blocks = num_blocks * num_blocks_per_kv_block kernel_num_blocks = num_blocks * num_blocks_per_kv_block
# For MLA with compression, storage_block_size != block_size
if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
shape_block_size = kv_cache_spec.storage_block_size
else:
shape_block_size = kernel_block_size
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks, kernel_num_blocks,
kernel_block_size, shape_block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size, kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype, cache_dtype_str=self.cache_config.cache_dtype,
...@@ -6613,12 +6630,31 @@ class GPUModelRunner( ...@@ -6613,12 +6630,31 @@ class GPUModelRunner(
kv_cache_stride_order.index(i) kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order)) for i in range(len(kv_cache_stride_order))
] ]
kv_caches[layer_name] = (
kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
.view(dtype) if kv_cache_spec.page_size_padded is not None:
.view(kv_cache_shape) # Use strided view to handle page_size_bytes that
.permute(*inv_order) # include padding. This follows
) # the same pattern as MambaSpec handling below.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for MLA backends but NOT for
# standard attention backends whose shape starts with
# a K/V dimension of size 2.
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
raw_tensor,
size=kv_cache_shape,
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment