Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: default avatarYifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarqizixi <zixi@inferact.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <yongye@inferact.ai>
Co-authored-by: default avatarSimon Mo <simon@inferact.ai>
Co-authored-by: default avatarBugen Zhao <i@bugenzhao.com>
Co-authored-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoy Wang <yasong.wang@inferact.ai>
Co-authored-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarZhewen Li <jerven.vllm@gmail.com>
Co-authored-by: default avatarZijing Liu <liuzijing2014@gmail.com>
Co-authored-by: default avatarkhluu <khluu000@gmail.com>
Co-authored-by: default avatarqizixi <zixi@inferact.ai>
Co-authored-by: default avatarZhewen Li <zhewenli@inferact.ai>
parent 32e45636
<|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
<|DSML|tool_calls>
<|DSML|invoke name="$TOOL_NAME">
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
...
</|DSML|invoke>
<|DSML|invoke name="$TOOL_NAME2">
...
</|DSML|invoke>
</|DSML|tool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
Otherwise, output directly after </think> with tool calls or final response.
### Available Tool Schemas
{"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
<|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
<|DSML|tool_calls>
<|DSML|invoke name="search">
<|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
</|DSML|invoke>
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
\ No newline at end of file
<|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
为了让你快速了解,我整理了一个简要的表格:
| 维度 | 热海大滚锅的地位与特点 |
| :--- | :--- |
| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
### 💡 游玩攻略与温馨提示
如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
- **门票与开放时间**:
- **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
- **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
- **特色体验**:
- **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
- **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
- **注意事项**:
- **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
- **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from pathlib import Path
from types import SimpleNamespace
import pytest
from vllm.entrypoints.chat_utils import parse_chat_messages
from vllm.renderers.registry import RENDERER_REGISTRY
from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer
from vllm.tokenizers.registry import TokenizerRegistry
FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4"
class FakeHfTokenizer:
vocab_size = 100
def get_added_vocab(self) -> dict[str, int]:
return {"</think>": 100}
def encode(
self,
text: str,
add_special_tokens: bool = False,
**kwargs,
) -> list[int]:
self.last_encode = (text, add_special_tokens, kwargs)
return [len(text)]
def _tokenizer():
return get_deepseek_v4_tokenizer(FakeHfTokenizer())
def _model_config():
return SimpleNamespace(
multimodal_config=None,
allowed_local_media_path="",
allowed_media_domains=None,
)
def _load_reference_case(case_id: int):
data = json.loads((FIXTURES_DIR / f"test_input_{case_id}.json").read_text())
if isinstance(data, dict):
return data["messages"], data.get("tools")
return data, None
def _render_reference_case(case_id: int, **kwargs):
messages, tools = _load_reference_case(case_id)
conversation, _, _ = parse_chat_messages(
messages,
_model_config(),
content_format="string",
)
return _tokenizer().apply_chat_template(
conversation=conversation,
messages=messages,
tools=tools,
tokenize=False,
**kwargs,
)
def test_deepseek_v4_tokenizer_registered():
assert TokenizerRegistry.load_tokenizer_cls("deepseek_v4").__name__ == (
"DeepseekV4Tokenizer"
)
assert RENDERER_REGISTRY.load_renderer_cls("deepseek_v4").__name__ == (
"DeepseekV4Renderer"
)
def test_deepseek_v4_defaults_to_chat_mode():
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
)
assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|></think>")
@pytest.mark.parametrize("kwargs", [{"thinking": True}, {"enable_thinking": True}])
def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs):
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
**kwargs,
)
assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|><think>")
def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools():
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
}
]
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Weather?"}],
tools=tools,
tokenize=False,
)
assert "## Tools" in prompt
assert "<|DSML|tool_calls>" in prompt
assert "</|DSML|tool_calls>" in prompt
assert "function_calls" not in prompt
assert '"name": "get_weather"' in prompt
assert prompt.endswith("<|User|>Weather?<|Assistant|></think>")
def test_deepseek_v4_renders_parsed_history_tool_arguments():
messages = [
{"role": "user", "content": "List the repo"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "str_replace_editor",
"arguments": '{"command": "view", "path": "/testbed"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "file list",
},
]
tools = [
{
"type": "function",
"function": {
"name": "str_replace_editor",
"description": "Edit files",
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string"},
"path": {"type": "string"},
},
"required": ["command", "path"],
},
},
}
]
conversation, _, _ = parse_chat_messages(
messages,
_model_config(),
content_format="string",
)
prompt = _tokenizer().apply_chat_template(
conversation=conversation,
messages=messages,
tools=tools,
tokenize=False,
)
assert '<|DSML|parameter name="command" string="true">view' in prompt
assert '<|DSML|parameter name="path" string="true">/testbed' in prompt
assert 'parameter name="arguments"' not in prompt
@pytest.mark.parametrize("reasoning_effort", ["none", "low", "medium", "high"])
def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort):
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
enable_thinking=True,
reasoning_effort=reasoning_effort,
)
assert prompt.endswith("<|Assistant|><think>")
assert "Reasoning Effort: Absolute maximum" not in prompt
def test_deepseek_v4_preserves_reference_max_reasoning_effort():
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
enable_thinking=True,
reasoning_effort="max",
)
assert prompt.startswith(
"<|begin▁of▁sentence|>Reasoning Effort: Absolute maximum"
)
@pytest.mark.parametrize(
("case_id", "kwargs"),
[
(1, {"thinking": True}),
(2, {"thinking": True}),
(3, {"thinking": True}),
(4, {}),
],
)
def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs):
prompt = _render_reference_case(case_id, **kwargs)
expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text()
assert prompt == expected
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DeepSeekV4ToolParser."""
import json
from unittest.mock import MagicMock
from vllm.tool_parsers import ToolParserManager
from vllm.tool_parsers.deepseekv4_tool_parser import DeepSeekV4ToolParser
MOCK_TOKENIZER = MagicMock()
MOCK_TOKENIZER.get_vocab.return_value = {}
TC_START = "<|DSML|tool_calls>"
TC_END = "</|DSML|tool_calls>"
INV_START = '<|DSML|invoke name="'
INV_END = "</|DSML|invoke>"
PARAM_START = '<|DSML|parameter name="'
PARAM_END = "</|DSML|parameter>"
def make_parser(tools=None) -> DeepSeekV4ToolParser:
return DeepSeekV4ToolParser(MOCK_TOKENIZER, tools=tools)
def make_request(tools=None) -> MagicMock:
req = MagicMock()
req.tools = tools
return req
def build_tool_call(func_name: str, params: dict[str, str]) -> str:
param_strs = "".join(
f'{PARAM_START}{k}" string="true">{v}{PARAM_END}\n' for k, v in params.items()
)
return f'{TC_START}\n{INV_START}{func_name}">\n{param_strs}{INV_END}\n{TC_END}'
def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7):
deltas = []
previous_text = ""
for start in range(0, len(full_text), chunk_size):
delta_text = full_text[start : start + chunk_size]
current_text = previous_text + delta_text
delta = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[1],
request=make_request(),
)
previous_text = current_text
if delta is not None:
deltas.append(delta)
return deltas
def reconstruct_args(deltas, tool_index: int = 0) -> str:
fragments = []
for delta in deltas:
if delta.tool_calls:
for tool_call in delta.tool_calls:
if (
tool_call.index == tool_index
and tool_call.function
and tool_call.function.arguments
):
fragments.append(tool_call.function.arguments)
return "".join(fragments)
def test_registered():
assert ToolParserManager.get_tool_parser("deepseek_v4") is DeepSeekV4ToolParser
def test_extract_tool_calls():
parser = make_parser()
model_output = "Let me check. " + build_tool_call(
"get_weather", {"location": "Beijing", "unit": "celsius"}
)
result = parser.extract_tool_calls(model_output, make_request())
assert result.tools_called
assert result.content == "Let me check. "
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call.function.name == "get_weather"
assert json.loads(tool_call.function.arguments) == {
"location": "Beijing",
"unit": "celsius",
}
def test_function_calls_block_is_not_accepted():
parser = make_parser()
model_output = build_tool_call("search", {"query": "vllm"}).replace(
"tool_calls", "function_calls"
)
result = parser.extract_tool_calls(model_output, make_request())
assert not result.tools_called
assert result.content == model_output
def test_streaming_extracts_complete_invokes():
parser = make_parser()
full_text = build_tool_call("search", {"query": "deepseek v4"})
deltas = stream(parser, full_text, chunk_size=5)
names = [
tool_call.function.name
for delta in deltas
if delta.tool_calls
for tool_call in delta.tool_calls
]
assert names == ["search"]
assert json.loads(reconstruct_args(deltas)) == {"query": "deepseek v4"}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.v1.attention.utils import create_vllm_config
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size():
"""Regression test: DeepseekV4 compression path must compute slot_mapping from
compressed positions, not reuse the uncompressed common metadata mapping.
"""
device = torch.device("cuda")
# storage_block_size = block_size // compress_ratio = 256 // 4 = 64
kv_cache_spec = MLAAttentionSpec(
block_size=256,
num_kv_heads=1,
head_size=128,
dtype=torch.bfloat16,
compress_ratio=4,
)
vllm_config = create_vllm_config(max_model_len=1024)
builder = DeepseekV32IndexerMetadataBuilder(
kv_cache_spec=kv_cache_spec,
layer_names=["dummy"],
vllm_config=vllm_config,
device=device,
)
# Construct a single request where:
# - num_computed = 240 (=> compressed_pos_start = 60)
# - query_len = 40 (=> num_groups = 10)
# => compressed positions are 60..69 which cross the storage block boundary at 64.
query_start_loc = torch.tensor([0, 40], dtype=torch.int32, device=device)
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor([280], dtype=torch.int32, device=device) # 240 + 40
# Two blocks: compressed positions 0..63 map to block 5, 64..127 map to block 7.
block_table_tensor = torch.tensor([[5, 7]], dtype=torch.int32, device=device)
# Dummy uncompressed slot mapping (length == uncompressed num_actual_tokens).
slot_mapping = torch.full((40,), -123, dtype=torch.int64, device=device)
common = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens.cpu(),
num_reqs=1,
num_actual_tokens=40,
max_query_len=40,
max_seq_len=280,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)
md = builder.build(common_prefix_len=0, common_attn_metadata=common)
# The compressed slot_mapping retains the original uncompressed size (40).
# Only every compress_ratio-th position gets a valid slot; the rest are -1.
assert md.slot_mapping.numel() == 40
valid_slots = md.slot_mapping[md.slot_mapping >= 0]
assert valid_slots.numel() == 10 # 40 tokens / compress_ratio 4
storage_bs = kv_cache_spec.storage_block_size # 64
# Compressed positions 60..63 land in block 5, positions 64..69 in block 7.
expected = torch.tensor(
[
5 * storage_bs + 60,
5 * storage_bs + 61,
5 * storage_bs + 62,
5 * storage_bs + 63,
]
+ [
7 * storage_bs + 0,
7 * storage_bs + 1,
7 * storage_bs + 2,
7 * storage_bs + 3,
7 * storage_bs + 4,
7 * storage_bs + 5,
],
dtype=torch.int64,
device=device,
)
torch.testing.assert_close(valid_slots, expected)
...@@ -1855,10 +1855,11 @@ def test_generate_scheduler_kv_cache_config(): ...@@ -1855,10 +1855,11 @@ def test_generate_scheduler_kv_cache_config():
def new_mla_spec(cache_dtype_str=None): def new_mla_spec(cache_dtype_str=None):
# head_size = kv_lora_rank(512) + qk_rope_head_dim(64) = 576
return MLAAttentionSpec( return MLAAttentionSpec(
block_size=16, block_size=16,
num_kv_heads=16, num_kv_heads=1,
head_size=64, head_size=576,
dtype=torch.float32, dtype=torch.float32,
cache_dtype_str=cache_dtype_str, cache_dtype_str=cache_dtype_str,
) )
......
...@@ -557,19 +557,19 @@ def test_prefill_hybrid_model_eagle(): ...@@ -557,19 +557,19 @@ def test_prefill_hybrid_model_eagle():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == num_full_blocks assert len(req1.block_hashes) == num_full_blocks
assert computed_blocks.get_block_ids() == ( assert computed_blocks.get_block_ids() == (
[1, 2, 3, 4], [1, 2, 3, 4, 5],
[0, 9, 10, 11], [0, 0, 10, 11, 12],
[0, 16, 17, 18], [0, 0, 17, 18, 19],
) )
assert num_computed_tokens == 4 * block_size assert num_computed_tokens == 5 * block_size
num_new_tokens = len(all_token_ids) - num_computed_tokens num_new_tokens = len(all_token_ids) - num_computed_tokens
blocks = manager.allocate_slots( blocks = manager.allocate_slots(
req1, num_new_tokens, num_computed_tokens, computed_blocks req1, num_new_tokens, num_computed_tokens, computed_blocks
) )
assert blocks is not None and blocks.get_block_ids() == ( assert blocks is not None and blocks.get_block_ids() == (
[22, 23, 24], [22, 23],
[25, 26, 27], [24, 25],
[28, 29, 30], [26, 27],
) )
for block_per_group in computed_blocks.blocks: for block_per_group in computed_blocks.blocks:
for block in block_per_group: for block in block_per_group:
...@@ -591,7 +591,7 @@ def test_prefill_hybrid_model_eagle(): ...@@ -591,7 +591,7 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2), make_block_hash_with_group_id(block_hashes[0], 2),
], ],
4, 5,
) )
# Evict the first block of full attention, makes total cache miss. # Evict the first block of full attention, makes total cache miss.
...@@ -605,7 +605,7 @@ def test_prefill_hybrid_model_eagle(): ...@@ -605,7 +605,7 @@ def test_prefill_hybrid_model_eagle():
0, 0,
) )
# Evict the last block of all layers, reduces the hit length to 3. # Evict the last block of all layers, reduces the hit length to 4.
_test_partial_request_hit( _test_partial_request_hit(
manager, manager,
block_size, block_size,
...@@ -617,10 +617,10 @@ def test_prefill_hybrid_model_eagle(): ...@@ -617,10 +617,10 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[-1], 1), make_block_hash_with_group_id(block_hashes[-1], 1),
make_block_hash_with_group_id(block_hashes[-1], 2), make_block_hash_with_group_id(block_hashes[-1], 2),
], ],
3, 4,
) )
# Evict the last block of full attention, reduces the hit length to 3. # Evict the last block of full attention, reduces the hit length to 4.
_test_partial_request_hit( _test_partial_request_hit(
manager, manager,
block_size, block_size,
...@@ -628,7 +628,7 @@ def test_prefill_hybrid_model_eagle(): ...@@ -628,7 +628,7 @@ def test_prefill_hybrid_model_eagle():
"5", "5",
all_token_ids, all_token_ids,
[make_block_hash_with_group_id(block_hashes[-1], 0)], [make_block_hash_with_group_id(block_hashes[-1], 0)],
3, 4,
) )
# Since the last block of full attention is dropped for eagle, evict # Since the last block of full attention is dropped for eagle, evict
...@@ -655,12 +655,11 @@ def test_prefill_hybrid_model_eagle(): ...@@ -655,12 +655,11 @@ def test_prefill_hybrid_model_eagle():
3, 3,
) )
# Evict different set of blocks for full attention and sliding window makes # Evict different set of blocks for full attention and sliding window.
# total cache miss. # Full loses its last block so it drops to 4 full blocks after the eagle
# The cache hit length of full attention is 4 * block_size. # pop; SWA lost block 0 (outside the sliding window of the final hit),
# The cache hit length of sliding window is 3 * block_size. # which is not required for the K+1 anchor at position 4. Coordinated
# Then it is cache miss as the two type of layers # single-drop aligns both groups at hit=4.
# have different hit length.
_test_partial_request_hit( _test_partial_request_hit(
manager, manager,
block_size, block_size,
...@@ -672,7 +671,7 @@ def test_prefill_hybrid_model_eagle(): ...@@ -672,7 +671,7 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2), make_block_hash_with_group_id(block_hashes[0], 2),
], ],
0, 4,
) )
...@@ -893,7 +892,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): ...@@ -893,7 +892,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
# - 2 groups: 1 full + 1 other # - 2 groups: 1 full + 1 other
_EAGLE_HYBRID_MODEL_TEST_CASES = [ _EAGLE_HYBRID_MODEL_TEST_CASES = [
# 2 groups: 1 full + 1 other # 2 groups: 1 full + 1 other
pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"), pytest.param(["full", "sliding_window"], 3, id="2g-full+sw"),
] ]
......
...@@ -1892,6 +1892,7 @@ def create_scheduler_with_priority( ...@@ -1892,6 +1892,7 @@ def create_scheduler_with_priority(
log_stats=True, log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config), structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size, block_size=block_size,
hash_block_size=block_size,
) )
...@@ -4008,6 +4009,7 @@ def _create_encoder_decoder_scheduler( ...@@ -4008,6 +4009,7 @@ def _create_encoder_decoder_scheduler(
vllm_config=vllm_config, vllm_config=vllm_config,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
block_size=block_size, block_size=block_size,
hash_block_size=block_size,
structured_output_manager=StructuredOutputManager(vllm_config), structured_output_manager=StructuredOutputManager(vllm_config),
) )
......
...@@ -91,8 +91,10 @@ def test_basic_interface(): ...@@ -91,8 +91,10 @@ def test_basic_interface():
assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"] assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"]
req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id] req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id]
# local_block_ids is list[list[int]] (per-group); flatten for comparison.
all_block_ids = [bid for group in req_meta.local_block_ids for bid in group]
for block_id, block in zip( for block_id, block in zip(
req_meta.local_block_ids, all_block_ids,
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id request_id
], ],
...@@ -228,15 +230,15 @@ def test_scheduler_request_finished(): ...@@ -228,15 +230,15 @@ def test_scheduler_request_finished():
# Case: Capped length (Successful prefill, need to send to decoder) # Case: Capped length (Successful prefill, need to send to decoder)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11]) delay_free, _ = scheduler_connector.request_finished(request, block_ids=([10, 11],))
assert delay_free is True assert delay_free is True
assert "id-1" in scheduler_connector._reqs_need_send assert "id-1" in scheduler_connector._reqs_need_send
assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11] assert scheduler_connector._reqs_need_send["id-1"][1] == [[10, 11]]
# Case: Aborted (No need to transfer, free blocks immediately) # Case: Aborted (No need to transfer, free blocks immediately)
scheduler_connector._reqs_need_send.clear() scheduler_connector._reqs_need_send.clear()
request.status = RequestStatus.FINISHED_ABORTED request.status = RequestStatus.FINISHED_ABORTED
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12]) delay_free, _ = scheduler_connector.request_finished(request, block_ids=([12],))
assert delay_free is False assert delay_free is False
assert len(scheduler_connector._reqs_need_send) == 0 assert len(scheduler_connector._reqs_need_send) == 0
assert "id-1" in scheduler_connector._reqs_not_processed assert "id-1" in scheduler_connector._reqs_not_processed
...@@ -334,7 +336,7 @@ async def test_kv_producer(monkeypatch): ...@@ -334,7 +336,7 @@ async def test_kv_producer(monkeypatch):
send_meta = SendBlockMeta( send_meta = SendBlockMeta(
p_req_id="p-req-1", p_req_id="p-req-1",
transfer_id=transfer_id, transfer_id=transfer_id,
local_block_ids=[10, 11], local_block_ids=[[10, 11]],
ready=asyncio.Event(), ready=asyncio.Event(),
) )
prefill_worker.reqs_need_send[transfer_id] = send_meta prefill_worker.reqs_need_send[transfer_id] = send_meta
...@@ -346,7 +348,7 @@ async def test_kv_producer(monkeypatch): ...@@ -346,7 +348,7 @@ async def test_kv_producer(monkeypatch):
remote_port=54321, remote_port=54321,
remote_tp_size=1, remote_tp_size=1,
remote_tp_rank=0, remote_tp_rank=0,
req_blocks={"d-req-1": (transfer_id, [20, 21])}, req_blocks={"d-req-1": (transfer_id, [[20, 21]])},
kv_caches_base_addr=[0x2000], kv_caches_base_addr=[0x2000],
block_lens=[block_len], block_lens=[block_len],
) )
...@@ -389,7 +391,7 @@ async def test_kv_producer(monkeypatch): ...@@ -389,7 +391,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0 send_meta.sent = 0
send_meta.ready.set() send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20]) xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20]])
# Worker processes the consumer's request # Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# Verify transfer parameters are correct: 11 to 20 # Verify transfer parameters are correct: 11 to 20
...@@ -407,7 +409,7 @@ async def test_kv_producer(monkeypatch): ...@@ -407,7 +409,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0 send_meta.sent = 0
send_meta.ready.set() send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22]) xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21, 22]])
# Worker processes the consumer's request # Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# This should not be called because error. # This should not be called because error.
...@@ -424,7 +426,7 @@ async def test_kv_producer(monkeypatch): ...@@ -424,7 +426,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0 send_meta.sent = 0
send_meta.ready.clear() send_meta.ready.clear()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
# Worker processes the consumer's request # Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# This should not be called because timeout. # This should not be called because timeout.
...@@ -443,7 +445,7 @@ async def test_kv_producer(monkeypatch): ...@@ -443,7 +445,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0 send_meta.sent = 0
send_meta.ready.set() send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
# Worker processes the consumer's request # Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
mock_send_blocks.assert_called_once() mock_send_blocks.assert_called_once()
...@@ -481,7 +483,7 @@ async def test_kv_consumuer(monkeypatch): ...@@ -481,7 +483,7 @@ async def test_kv_consumuer(monkeypatch):
"d-req-1": PullReqMeta( "d-req-1": PullReqMeta(
d_req_id="d-req-1", d_req_id="d-req-1",
transfer_id="xfer-req-1", transfer_id="xfer-req-1",
local_block_ids=[100, 101], local_block_ids=[[100, 101]],
remote_engine_id="p-engine", remote_engine_id="p-engine",
remote_bootstrap_addr="http://bootstrap:33333", remote_bootstrap_addr="http://bootstrap:33333",
pull_tasks_count=1, pull_tasks_count=1,
...@@ -514,7 +516,7 @@ async def test_kv_consumuer(monkeypatch): ...@@ -514,7 +516,7 @@ async def test_kv_consumuer(monkeypatch):
assert sent_meta.remote_hostname == "127.0.0.1" assert sent_meta.remote_hostname == "127.0.0.1"
assert sent_meta.remote_port == 54321 assert sent_meta.remote_port == 54321
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101]) assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [[100, 101]])
# Verify internal state is updated correctly. # Verify internal state is updated correctly.
assert "d-req-1" in decode_worker.finished_recving_reqs assert "d-req-1" in decode_worker.finished_recving_reqs
...@@ -538,7 +540,7 @@ async def test_worker_get_finished_timeout(monkeypatch): ...@@ -538,7 +540,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta( prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta(
p_req_id="p-req-expired", p_req_id="p-req-expired",
transfer_id="tx-expired", transfer_id="tx-expired",
local_block_ids=[1, 2], local_block_ids=[[1, 2]],
ready=MagicMock(), ready=MagicMock(),
expire_time=time.perf_counter() - 100, expire_time=time.perf_counter() - 100,
) )
...@@ -547,7 +549,7 @@ async def test_worker_get_finished_timeout(monkeypatch): ...@@ -547,7 +549,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta( prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta(
p_req_id="p-req-active", p_req_id="p-req-active",
transfer_id="tx-active", transfer_id="tx-active",
local_block_ids=[3, 4], local_block_ids=[[3, 4]],
ready=MagicMock(), ready=MagicMock(),
expire_time=time.perf_counter() + 100, expire_time=time.perf_counter() + 100,
) )
...@@ -703,7 +705,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -703,7 +705,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
prefill_worker.sender_loop = asyncio.get_event_loop() prefill_worker.sender_loop = asyncio.get_event_loop()
transfer_id = "xfer-hetero-1" transfer_id = "xfer-hetero-1"
local_block_ids = [10, 11] local_block_ids = [[10, 11]]
send_meta = SendBlockMeta( send_meta = SendBlockMeta(
p_req_id="p-req-h1", p_req_id="p-req-h1",
transfer_id=transfer_id, transfer_id=transfer_id,
...@@ -720,9 +722,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -720,9 +722,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
mock_socket.send_multipart = AsyncMock() mock_socket.send_multipart = AsyncMock()
identity = b"consumer-hetero" identity = b"consumer-hetero"
# Assign different remote block IDs per D rank # Assign different remote block IDs per D rank (nested per-group)
d_rank_remote_blocks = { d_rank_remote_blocks = {
rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks) rank: [[20 + i * 10, 21 + i * 10]] for i, rank in enumerate(target_d_ranks)
} }
with patch.object( with patch.object(
...@@ -757,11 +759,15 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -757,11 +759,15 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
dst_ptrs = call_args[2] dst_ptrs = call_args[2]
lengths = call_args[3] lengths = call_args[3]
# Flatten nested per-group block IDs for assertions
flat_local = [b for g in local_block_ids for b in g]
flat_remote = [b for g in remote_block_ids for b in g]
# Heterogeneous TP: blocks cannot be coalesced because # Heterogeneous TP: blocks cannot be coalesced because
# local and remote block_lens differ # local and remote block_lens differ
assert len(src_ptrs) == len(local_block_ids) assert len(src_ptrs) == len(flat_local)
assert len(dst_ptrs) == len(local_block_ids) assert len(dst_ptrs) == len(flat_local)
assert len(lengths) == len(local_block_ids) assert len(lengths) == len(flat_local)
# Compute expected offsets based on TP ratio # Compute expected offsets based on TP ratio
if d_tp_size <= P_TP_SIZE: if d_tp_size <= P_TP_SIZE:
...@@ -775,9 +781,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -775,9 +781,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
expected_dst_off = 0 expected_dst_off = 0
expected_xfer_len = remote_block_len expected_xfer_len = remote_block_len
for idx, (lblk, rblk) in enumerate( for idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)):
zip(local_block_ids, remote_block_ids)
):
assert src_ptrs[idx] == ( assert src_ptrs[idx] == (
0x1000 + lblk * local_block_len + expected_src_off 0x1000 + lblk * local_block_len + expected_src_off
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MooncakeConnector HMA (Hybrid Memory Architecture) support.
Covers sliding-window clipping, multi-group metadata shape, multi-group
send trimming, and group-count invariant checking in _build_transfer_params.
"""
import asyncio
from unittest.mock import patch
import pytest
from vllm.config import set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import (
KVConnectorRole,
MooncakeConnector,
MooncakeConnectorMetadata,
MooncakeConnectorScheduler,
MooncakeXferMetadata,
SendBlockMeta,
TransferRegion,
)
from .test_mooncake_connector import FakeMooncakeWrapper, patch_worker_dependencies
from .utils import create_request, create_vllm_config, make_kv_cache_config
# ---------------------------------------------------------------------------
# test_sw_sizes: blocks_per_sw computed from KVCacheConfig
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,expected_blocks_per_sw",
[
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128+1)
(True, [0, 128 + 1]),
# SWA disabled: only FullAttentionSpec (0)
(False, [0]),
],
)
def test_sw_sizes(swa_enabled, expected_blocks_per_sw):
"""blocks_per_sw is correctly computed based on SWA enabled/disabled."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
# Override so HMA detection works
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler.blocks_per_sw == expected_blocks_per_sw
# ---------------------------------------------------------------------------
# test_is_hma_required: derived from kv_cache_config groups
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,disable_hma,expected_is_hma",
[
(True, False, True), # SWA group present, HMA enabled
(True, True, False), # SWA group present, but HMA disabled
(False, False, False), # FA only, HMA not needed
],
)
def test_is_hma_required(swa_enabled, disable_hma, expected_is_hma):
"""_is_hma_required is correctly derived from kv_cache_config."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = disable_hma
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=swa_enabled
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._is_hma_required is expected_is_hma
# ---------------------------------------------------------------------------
# test_get_sw_clipped_blocks: sliding-window clipping logic
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_get_sw_clipped_blocks():
"""get_sw_clipped_blocks clips SWA group but keeps FA group intact."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
# SW=128 tokens → 128/16 = 8 blocks + 1 = 9 blocks_per_sw
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=True, sw_size=128
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler.blocks_per_sw == [0, 9]
# FA group: 20 blocks, SW group: 20 blocks (exceeds window)
fa_blocks = list(range(20))
sw_blocks = list(range(100, 120))
block_ids = (fa_blocks, sw_blocks)
clipped = scheduler.get_sw_clipped_blocks(block_ids)
# FA: untouched (blocks_per_sw[0] = 0)
assert clipped[0] == fa_blocks
# SW: clipped to last 9 blocks
assert clipped[1] == sw_blocks[-9:]
assert len(clipped[1]) == 9
@pytest.mark.cpu_test
def test_get_sw_clipped_blocks_noop_no_hma():
"""get_sw_clipped_blocks is a no-op when HMA is not required."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
# FA only → _is_hma_required = False
kv_cache_config = make_kv_cache_config(block_size=block_size, swa_enabled=False)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._is_hma_required is False
block_ids = ([1, 2, 3],)
clipped = scheduler.get_sw_clipped_blocks(block_ids)
assert clipped == [[1, 2, 3]]
# ---------------------------------------------------------------------------
# test_metadata_hma_block_ids: MooncakeConnectorMetadata stores per-group IDs
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_metadata_hma_block_ids():
"""MooncakeConnectorMetadata.add_new_req stores per-group block IDs."""
metadata = MooncakeConnectorMetadata()
# FA group: 6 blocks, SW group: 3 blocks (clipped)
fa_blocks = [0, 1, 2, 3, 4, 5]
sw_blocks = [10, 11, 12]
# Test recv path
metadata.add_new_req(
request_id="recv-req",
local_block_ids=[fa_blocks, sw_blocks],
kv_transfer_params={
"transfer_id": "recv-req",
"remote_engine_id": "remote-engine",
"remote_bootstrap_addr": "http://bootstrap:33333",
},
load_remote_cache=True,
)
assert "recv-req" in metadata.reqs_to_recv["remote-engine"]
req_meta = metadata.reqs_to_recv["remote-engine"]["recv-req"]
assert len(req_meta.local_block_ids) == 2
assert req_meta.local_block_ids[0] == fa_blocks
assert req_meta.local_block_ids[1] == sw_blocks
# Test send path
metadata.add_new_req(
request_id="send-req",
local_block_ids=[fa_blocks, sw_blocks],
kv_transfer_params={
"transfer_id": "send-req",
},
load_remote_cache=False,
)
assert "send-req" in metadata.reqs_to_send
transfer_id, stored_blocks = metadata.reqs_to_send["send-req"]
assert transfer_id == "send-req"
assert len(stored_blocks) == 2
assert stored_blocks[0] == fa_blocks
assert stored_blocks[1] == sw_blocks
# ---------------------------------------------------------------------------
# test_build_transfer_params_multi_group_trimming
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
".mooncake_connector.TransferEngine",
FakeMooncakeWrapper,
)
async def test_build_transfer_params_multi_group_trimming(monkeypatch):
"""_build_transfer_params trims per-group blocks when local > remote."""
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
worker = connector.connector_worker
block_len = 4096
# Call _build_transfer_params directly (avoids send_kv_to_decode
# async event loop complexity).
transfer_id = "xfer-hma-trim"
send_meta = SendBlockMeta(
p_req_id="p-trim",
transfer_id=transfer_id,
# FA: 4 blocks, SW: 3 blocks (producer has more)
local_block_ids=[[10, 11, 12, 13], [20, 21, 22]],
ready=asyncio.Event(),
)
xfer_meta = MooncakeXferMetadata(
remote_hostname="consumer-host",
remote_port=54321,
remote_tp_size=1,
remote_tp_rank=0,
req_blocks={
"d-trim": (
transfer_id,
# FA: 2 blocks, SW: 2 blocks (consumer needs fewer)
[[30, 31], [40, 41]],
)
},
kv_caches_base_addr=[0x2000],
block_lens=[block_len],
)
local_regions = [
TransferRegion(
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
),
]
remote_regions = [
TransferRegion(
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
),
]
ready_reqs = [("d-trim", send_meta)]
(
src_ptrs,
dst_ptrs,
lengths,
err_reqs,
err_msg,
) = await worker._build_transfer_params(
ready_reqs, xfer_meta, local_regions, remote_regions
)
# No errors
assert err_reqs == []
assert err_msg is None
# After trimming: FA [10..13] → last 2 → [12,13]; SW [20..22] → last 2 → [21,22]
# Flattened: [12,13,21,22] = 4 blocks → coalesced into some transfers
assert len(src_ptrs) > 0
assert len(dst_ptrs) == len(src_ptrs)
assert len(lengths) == len(src_ptrs)
worker.shutdown()
# ---------------------------------------------------------------------------
# test_build_transfer_params_group_count_mismatch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
".mooncake_connector.TransferEngine",
FakeMooncakeWrapper,
)
async def test_build_transfer_params_group_count_mismatch(monkeypatch):
"""_build_transfer_params reports an error when group counts differ."""
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
worker = connector.connector_worker
block_len = 4096
transfer_id = "xfer-mismatch"
send_meta = SendBlockMeta(
p_req_id="p-mismatch",
transfer_id=transfer_id,
# Producer has 2 groups
local_block_ids=[[10, 11], [20, 21]],
ready=asyncio.Event(),
)
# Consumer has only 1 group — group count mismatch
xfer_meta = MooncakeXferMetadata(
remote_hostname="consumer-host",
remote_port=54321,
remote_tp_size=1,
remote_tp_rank=0,
req_blocks={
"d-mismatch": (transfer_id, [[30, 31]]),
},
kv_caches_base_addr=[0x2000],
block_lens=[block_len],
)
local_regions = [
TransferRegion(
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
),
]
remote_regions = [
TransferRegion(
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
),
]
ready_reqs = [("d-mismatch", send_meta)]
(
src_ptrs,
dst_ptrs,
lengths,
err_reqs,
err_msg,
) = await worker._build_transfer_params(
ready_reqs, xfer_meta, local_regions, remote_regions
)
# Mismatched req is reported via err_reqs/err_msg with no transfers built.
assert err_reqs == ["d-mismatch"]
assert err_msg == "KV group count mismatch"
assert src_ptrs == []
assert dst_ptrs == []
assert lengths == []
worker.shutdown()
# ---------------------------------------------------------------------------
# test_request_finished_with_hma_groups
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_request_finished_with_hma_groups():
"""request_finished correctly handles per-group block_ids."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_producer",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=True, sw_size=128
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
request = create_request(request_id=1, do_remote_decode=True)
request.kv_transfer_params["transfer_id"] = request.request_id
from vllm.v1.request import RequestStatus
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
# 2 groups: FA with 10 blocks, SW with 20 blocks (will be clipped)
fa_blocks = list(range(10))
sw_blocks = list(range(100, 120))
block_ids = (fa_blocks, sw_blocks)
delay_free, _ = scheduler.request_finished(request, block_ids)
assert delay_free is True
assert request.request_id in scheduler._reqs_need_send
_, stored_blocks = scheduler._reqs_need_send[request.request_id]
# FA: untouched
assert stored_blocks[0] == fa_blocks
# SW: clipped to last 9 blocks (sw_size=128, block_size=16 → 8+1=9)
assert stored_blocks[1] == sw_blocks[-9:]
...@@ -76,6 +76,7 @@ def create_scheduler() -> Scheduler: ...@@ -76,6 +76,7 @@ def create_scheduler() -> Scheduler:
log_stats=True, log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config), structured_output_manager=StructuredOutputManager(vllm_config),
block_size=16, block_size=16,
hash_block_size=16,
) )
......
...@@ -7,7 +7,7 @@ set -e ...@@ -7,7 +7,7 @@ set -e
# Default values # Default values
# Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake # Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03" DEEPGEMM_GIT_REF="891d57b4db1071624b5c8fa0d1e51cb317fa709f"
WHEEL_DIR="" WHEEL_DIR=""
# Parse command line arguments # Parse command line arguments
......
...@@ -404,10 +404,24 @@ def rotary_embedding( ...@@ -404,10 +404,24 @@ def rotary_embedding(
head_size: int, head_size: int,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox: bool, is_neox: bool,
rope_dim_offset: int = 0,
inverse: bool = False,
) -> None: ) -> None:
torch.ops._C.rotary_embedding( if rope_dim_offset == 0 and not inverse:
positions, query, key, head_size, cos_sin_cache, is_neox torch.ops._C.rotary_embedding(
) positions, query, key, head_size, cos_sin_cache, is_neox
)
else:
torch.ops._C.rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
rope_dim_offset,
inverse,
)
# layer norm ops # layer norm ops
...@@ -2503,6 +2517,30 @@ def topk_sigmoid( ...@@ -2503,6 +2517,30 @@ def topk_sigmoid(
) )
def topk_hash_softplus_sqrt(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.topk_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
def grouped_topk( def grouped_topk(
scores: torch.Tensor, scores: torch.Tensor,
num_expert_group: int, num_expert_group: int,
......
...@@ -51,6 +51,9 @@ class AttentionConfig: ...@@ -51,6 +51,9 @@ class AttentionConfig:
use_prefill_query_quantization: bool = False use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill.""" """If set, quantize query for attention in prefill."""
use_fp4_indexer_cache: bool = False
"""If set, use fp4 indexer cache for dsv32 family model (not support yet)"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
Provide a hash that uniquely identifies all the configs Provide a hash that uniquely identifies all the configs
......
...@@ -51,6 +51,18 @@ class CacheConfig: ...@@ -51,6 +51,18 @@ class CacheConfig:
"""Whether block_size was explicitly provided. Derived automatically.""" """Whether block_size was explicitly provided. Derived automatically."""
user_specified_mamba_block_size: bool = field(default=False, init=False) user_specified_mamba_block_size: bool = field(default=False, init=False)
"""Whether mamba_block_size was explicitly provided. Derived automatically.""" """Whether mamba_block_size was explicitly provided. Derived automatically."""
hash_block_size: SkipValidation[int] | None = None # type: ignore
"""Block size (in tokens) used for computing Request's block_hashes.
This can be set to a finer granularity than the physical KV cache block
sizes (e.g. 8) as long as every KV cache group's `block_size` is divisible
by it. This enables prefix-caching keys to be computed at the finest common
granularity and then merged for larger physical block sizes.
This config is not static default. If left unspecified, vLLM will choose a
default based on the resolved KV cache groups (typically the smallest KV
cache block size when there are multiple groups).
"""
gpu_memory_utilization: float = Field(default=0.92, gt=0, le=1) gpu_memory_utilization: float = Field(default=0.92, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can """The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
...@@ -182,6 +194,8 @@ class CacheConfig: ...@@ -182,6 +194,8 @@ class CacheConfig:
"num_gpu_blocks_override", "num_gpu_blocks_override",
"enable_prefix_caching", "enable_prefix_caching",
"prefix_caching_hash_algo", "prefix_caching_hash_algo",
# Prefix-caching implementation detail (doesn't affect compiled graph).
"hash_block_size",
"mamba_page_size_padded", "mamba_page_size_padded",
"user_specified_block_size", "user_specified_block_size",
"user_specified_mamba_block_size", "user_specified_mamba_block_size",
......
...@@ -749,6 +749,7 @@ class CompilationConfig: ...@@ -749,6 +749,7 @@ class CompilationConfig:
"vllm::kda_attention", "vllm::kda_attention",
"vllm::sparse_attn_indexer", "vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer",
"vllm::deepseek_v4_attention",
] ]
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -50,7 +50,7 @@ class IrOpPriorityConfig: ...@@ -50,7 +50,7 @@ class IrOpPriorityConfig:
name: { name: {
provider: IrOp.registry[name].impls[provider].uuid() for provider in p provider: IrOp.registry[name].impls[provider].uuid() for provider in p
} }
for name, p in asdict(self).items() for name, p in asdict(self).items() # type: ignore[call-overload]
} }
return hash_factors(factors) return hash_factors(factors)
...@@ -77,7 +77,7 @@ class IrOpPriorityConfig: ...@@ -77,7 +77,7 @@ class IrOpPriorityConfig:
current_platform.import_ir_kernels() current_platform.import_ir_kernels()
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
for field in fields(self): for field in fields(self): # type: ignore[arg-type]
op_priority = getattr(self, field.name) op_priority = getattr(self, field.name)
assert op_priority is not None, ( assert op_priority is not None, (
f"IR op priority for {field.name} must be set" f"IR op priority for {field.name} must be set"
...@@ -98,7 +98,7 @@ class IrOpPriorityConfig: ...@@ -98,7 +98,7 @@ class IrOpPriorityConfig:
A helper to create an IrOpPriorityConfig where fields not specified in kwargs A helper to create an IrOpPriorityConfig where fields not specified in kwargs
use the given default list. use the given default list.
""" """
for field in fields(cls): for field in fields(cls): # type: ignore[arg-type]
if field.name not in kwargs: if field.name not in kwargs:
kwargs[field.name] = list(default) kwargs[field.name] = list(default)
...@@ -109,6 +109,7 @@ MoEBackend = Literal[ ...@@ -109,6 +109,7 @@ MoEBackend = Literal[
"auto", "auto",
"triton", "triton",
"deep_gemm", "deep_gemm",
"deep_gemm_mega_moe",
"cutlass", "cutlass",
"flashinfer_trtllm", "flashinfer_trtllm",
"flashinfer_cutlass", "flashinfer_cutlass",
...@@ -136,8 +137,9 @@ class KernelConfig: ...@@ -136,8 +137,9 @@ class KernelConfig:
"""Backend for MoE expert computation kernels. Available options: """Backend for MoE expert computation kernels. Available options:
- "auto": Automatically select the best backend based on model and hardware - "auto": Automatically select the best backend based on model and hardware
- "triton": Use Triton-based fused MoE kernels - "triton": Use Triton-based fused MoE kernels
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only) - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)
- "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels
- "cutlass": Use vLLM CUTLASS kernels - "cutlass": Use vLLM CUTLASS kernels
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
......
...@@ -83,7 +83,7 @@ logger = init_logger(__name__) ...@@ -83,7 +83,7 @@ logger = init_logger(__name__)
RunnerOption = Literal["auto", RunnerType] RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify"] ConvertType = Literal["none", "embed", "classify"]
ConvertOption = Literal["auto", ConvertType] ConvertOption = Literal["auto", ConvertType]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32", "deepseek_v4"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[ LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
...@@ -134,6 +134,7 @@ class ModelConfig: ...@@ -134,6 +134,7 @@ class ModelConfig:
- "slow" will always use the slow tokenizer. - "slow" will always use the slow tokenizer.
- "mistral" will always use the tokenizer from `mistral_common`. - "mistral" will always use the tokenizer from `mistral_common`.
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`. - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.
- "deepseek_v4" will always use the tokenizer from `deepseek_v4`.
- "qwen_vl" will always use the tokenizer from `qwen_vl`. - "qwen_vl" will always use the tokenizer from `qwen_vl`.
- Other custom values can be supported via plugins.""" - Other custom values can be supported via plugins."""
trust_remote_code: bool = False trust_remote_code: bool = False
...@@ -565,6 +566,8 @@ class ModelConfig: ...@@ -565,6 +566,8 @@ class ModelConfig:
self.tokenizer_mode = "qwen_vl" self.tokenizer_mode = "qwen_vl"
elif arch == "DeepseekV32ForCausalLM": elif arch == "DeepseekV32ForCausalLM":
self.tokenizer_mode = "deepseek_v32" self.tokenizer_mode = "deepseek_v32"
elif arch == "DeepseekV4ForCausalLM":
self.tokenizer_mode = "deepseek_v4"
if self.tokenizer_mode != "auto": if self.tokenizer_mode != "auto":
logger.info( logger.info(
...@@ -952,6 +955,7 @@ class ModelConfig: ...@@ -952,6 +955,7 @@ class ModelConfig:
# imports during override detection (e.g., MXFP4 imports Triton) # imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4", "mxfp4",
"gpt_oss_mxfp4", "gpt_oss_mxfp4",
"deepseek_v4_fp8",
"cpu_awq", "cpu_awq",
"humming", "humming",
"gguf", "gguf",
......
...@@ -287,13 +287,23 @@ class SpeculativeConfig: ...@@ -287,13 +287,23 @@ class SpeculativeConfig:
@staticmethod @staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
initial_architecture = hf_config.architectures[0] initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"): if hf_config.model_type in (
"deepseek_v3",
"deepseek_v32",
"glm_moe_dsa",
):
hf_config.model_type = "deepseek_mtp" hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp": if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None) n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update( hf_config.update(
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
) )
if hf_config.model_type == "deepseek_v4":
hf_config.model_type = "deepseek_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["DeepSeekV4MTPModel"]}
)
if hf_config.model_type in ("pangu_ultra_moe"): if hf_config.model_type in ("pangu_ultra_moe"):
hf_config.model_type = "pangu_ultra_moe_mtp" hf_config.model_type = "pangu_ultra_moe_mtp"
if hf_config.model_type == "pangu_ultra_moe_mtp": if hf_config.model_type == "pangu_ultra_moe_mtp":
......
...@@ -29,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -29,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
SupportsHMA,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
MooncakeBootstrapServer, MooncakeBootstrapServer,
...@@ -43,10 +44,12 @@ from vllm.distributed.parallel_state import ( ...@@ -43,10 +44,12 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from vllm.v1.worker.utils import select_common_block_size from vllm.v1.worker.utils import select_common_block_size
...@@ -252,7 +255,7 @@ class MooncakeXferMetadata( ...@@ -252,7 +255,7 @@ class MooncakeXferMetadata(
remote_port: int remote_port: int
remote_tp_size: int remote_tp_size: int
remote_tp_rank: int remote_tp_rank: int
req_blocks: dict[ReqId, tuple[TransferId, list[int]]] req_blocks: dict[ReqId, tuple[TransferId, list[list[int]]]]
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
block_lens: list[int] block_lens: list[int]
...@@ -280,7 +283,7 @@ class MooncakeXferResponse( ...@@ -280,7 +283,7 @@ class MooncakeXferResponse(
class PullReqMeta: class PullReqMeta:
d_req_id: ReqId d_req_id: ReqId
transfer_id: TransferId transfer_id: TransferId
local_block_ids: list[int] local_block_ids: list[list[int]]
remote_engine_id: EngineId remote_engine_id: EngineId
remote_bootstrap_addr: str remote_bootstrap_addr: str
# Set expire time to avoid infinitely sending requests. # Set expire time to avoid infinitely sending requests.
...@@ -293,7 +296,7 @@ class PullReqMeta: ...@@ -293,7 +296,7 @@ class PullReqMeta:
class SendBlockMeta: class SendBlockMeta:
p_req_id: ReqId p_req_id: ReqId
transfer_id: TransferId transfer_id: TransferId
local_block_ids: list[int] local_block_ids: list[list[int]]
ready: asyncio.Event ready: asyncio.Event
expire_time: float = float("inf") expire_time: float = float("inf")
need_send: int = 0 need_send: int = 0
...@@ -306,13 +309,13 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): ...@@ -306,13 +309,13 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
# Use (engine_id, dp_rank) to group reqs with same dp. # Use (engine_id, dp_rank) to group reqs with same dp.
# See comments in MooncakeBootstrapServer. # See comments in MooncakeBootstrapServer.
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {} self.reqs_to_send: dict[ReqId, tuple[TransferId, list[list[int]]]] = {}
self.reqs_not_processed: set[TransferId] = set() self.reqs_not_processed: set[TransferId] = set()
def add_new_req( def add_new_req(
self, self,
request_id: ReqId, request_id: ReqId,
local_block_ids: list[int], local_block_ids: list[list[int]],
kv_transfer_params: dict[str, Any], kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True, load_remote_cache: bool = True,
): ):
...@@ -330,7 +333,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): ...@@ -330,7 +333,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
self.reqs_to_send[request_id] = (transfer_id, local_block_ids) self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
class MooncakeConnector(KVConnectorBase_V1): class MooncakeConnector(KVConnectorBase_V1, SupportsHMA):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -344,13 +347,18 @@ class MooncakeConnector(KVConnectorBase_V1): ...@@ -344,13 +347,18 @@ class MooncakeConnector(KVConnectorBase_V1):
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER: if role == KVConnectorRole.SCHEDULER:
assert kv_cache_config is not None, (
"kv_cache_config is required for SCHEDULER role"
)
self.connector_scheduler: MooncakeConnectorScheduler | None = ( self.connector_scheduler: MooncakeConnectorScheduler | None = (
MooncakeConnectorScheduler(vllm_config, self.engine_id) MooncakeConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
) )
self.connector_worker: MooncakeConnectorWorker | None = None self.connector_worker: MooncakeConnectorWorker | None = None
elif role == KVConnectorRole.WORKER: elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) self.connector_worker = MooncakeConnectorWorker(
vllm_config, self.engine_id, kv_cache_config
)
@classmethod @classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig): def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
...@@ -401,6 +409,14 @@ class MooncakeConnector(KVConnectorBase_V1): ...@@ -401,6 +409,14 @@ class MooncakeConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
block_ids: list[int], block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, (block_ids,))
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]: ) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids) return self.connector_scheduler.request_finished(request, block_ids)
...@@ -445,8 +461,14 @@ class MooncakeConnector(KVConnectorBase_V1): ...@@ -445,8 +461,14 @@ class MooncakeConnector(KVConnectorBase_V1):
class MooncakeConnectorScheduler: class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(
self,
vllm_config: VllmConfig,
engine_id: str,
kv_cache_config: "KVCacheConfig",
):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
assert vllm_config.kv_transfer_config assert vllm_config.kv_transfer_config
self.is_kv_producer: bool = ( self.is_kv_producer: bool = (
...@@ -457,15 +479,49 @@ class MooncakeConnectorScheduler: ...@@ -457,15 +479,49 @@ class MooncakeConnectorScheduler:
) )
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
# Requests that need to start recv/send. # Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_recv: dict[ReqId, tuple[Request, list[list[int]]]] = {}
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_send: dict[ReqId, tuple[Request, list[list[int]]]] = {}
# Reqs to remove from processed set because they're not to send after # Reqs to remove from processed set because they're not to send after
# remote prefill or aborted. # remote prefill or aborted.
self._reqs_not_processed: set[TransferId] = set() self._reqs_not_processed: set[TransferId] = set()
# Compute sliding window block counts per KV cache group.
sw_sizes_tokens: list[tuple[int, int]] = [
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
else (0, self.block_size)
for g in kv_cache_config.kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to
# conservatively account for boundary overlap.
self.blocks_per_sw = [
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
for n_tokens, block_size in sw_sizes_tokens
]
def get_sw_clipped_blocks(
self,
block_ids: tuple[list[int], ...] | list[list[int]],
) -> list[list[int]]:
"""Clip per-group block IDs to sliding window size."""
if len(block_ids) == 0 or not self._is_hma_required:
return list(block_ids)
return [
blocks[-self.blocks_per_sw[i] :] if self.blocks_per_sw[i] > 0 else blocks
for i, blocks in enumerate(block_ids)
]
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]: ) -> tuple[int, bool]:
...@@ -530,9 +586,12 @@ class MooncakeConnectorScheduler: ...@@ -530,9 +586,12 @@ class MooncakeConnectorScheduler:
# If remote_blocks and num_external_tokens = 0, we have # If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call # a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P. # send_notif in _read_blocks to free the memory on the P.
local_block_ids = ( unhashed_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] blocks.get_unhashed_block_ids_all_groups()
if num_external_tokens > 0
else ()
) )
local_block_ids = self.get_sw_clipped_blocks(unhashed_block_ids)
# Get unhashed blocks to pull from remote. # Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request, local_block_ids) self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else: else:
...@@ -587,7 +646,7 @@ class MooncakeConnectorScheduler: ...@@ -587,7 +646,7 @@ class MooncakeConnectorScheduler:
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",
block_ids: list[int], block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]: ) -> tuple[bool, dict[str, Any] | None]:
""" """
Once a request is finished, determine whether request blocks Once a request is finished, determine whether request blocks
...@@ -630,10 +689,13 @@ class MooncakeConnectorScheduler: ...@@ -630,10 +689,13 @@ class MooncakeConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could # TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below # remove the conditional below
delay_free_blocks = len(block_ids) > 0 delay_free_blocks = any(len(group) > 0 for group in block_ids)
if delay_free_blocks: if delay_free_blocks:
self._reqs_need_send[request.request_id] = (request, block_ids) self._reqs_need_send[request.request_id] = (
request,
self.get_sw_clipped_blocks(block_ids),
)
return delay_free_blocks, None return delay_free_blocks, None
...@@ -641,7 +703,12 @@ class MooncakeConnectorScheduler: ...@@ -641,7 +703,12 @@ class MooncakeConnectorScheduler:
class MooncakeConnectorWorker: class MooncakeConnectorWorker:
"""Implementation of Worker side methods""" """Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(
self,
vllm_config: VllmConfig,
engine_id: str,
kv_cache_config: "KVCacheConfig | None" = None,
):
if TransferEngine is None: if TransferEngine is None:
logger.error("Mooncake is not available") logger.error("Mooncake is not available")
raise RuntimeError("Mooncake is not available") raise RuntimeError("Mooncake is not available")
...@@ -752,6 +819,7 @@ class MooncakeConnectorWorker: ...@@ -752,6 +819,7 @@ class MooncakeConnectorWorker:
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.kv_cache_config = kv_cache_config
self.use_mla = self.model_config.use_mla self.use_mla = self.model_config.use_mla
self._sync_block_size_with_kernel() self._sync_block_size_with_kernel()
...@@ -1103,27 +1171,61 @@ class MooncakeConnectorWorker: ...@@ -1103,27 +1171,61 @@ class MooncakeConnectorWorker:
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
for d_req_id, send_meta in ready_reqs: for d_req_id, send_meta in ready_reqs:
_, remote_block_ids = agent_meta.req_blocks[d_req_id] _, remote_block_ids_per_group = agent_meta.req_blocks[d_req_id]
num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0: if not remote_block_ids_per_group or all(
len(g) == 0 for g in remote_block_ids_per_group
):
continue continue
local_block_ids = send_meta.local_block_ids # Per-group partial hit trimming, then flatten.
# Partial prefix cache hit: just read uncomputed blocks. # With HMA, groups share the same KV tensor but use different
num_local_blocks = len(local_block_ids) # block ranges. We trim and concatenate so the coalescer and
if num_local_blocks < num_remote_blocks: # address math see one flat block list — same as non-HMA, but
# now including blocks from every group.
local_block_ids: list[int] = []
remote_block_ids: list[int] = []
has_block_error = False
if len(send_meta.local_block_ids) != len(remote_block_ids_per_group):
logger.error( logger.error(
"req %s: local blocks(%d) less than remote blocks(%d)!", "req %s: KV group count mismatch: local=%d, remote=%d",
d_req_id, d_req_id,
num_local_blocks, len(send_meta.local_block_ids),
num_remote_blocks, len(remote_block_ids_per_group),
) )
err_reqs.append(d_req_id)
if err_msg is None:
err_msg = "KV group count mismatch"
continue
for local_group, remote_group in zip(
send_meta.local_block_ids, remote_block_ids_per_group
):
n_local = len(local_group)
n_remote = len(remote_group)
if n_local < n_remote:
logger.error(
"req %s: local blocks(%d) < remote blocks(%d) "
"in a KV cache group",
d_req_id,
n_local,
n_remote,
)
has_block_error = True
break
if n_local > n_remote:
# Partial prefix cache hit: just read uncomputed blocks.
local_group = local_group[-n_remote:]
local_block_ids.extend(local_group)
remote_block_ids.extend(remote_group)
if has_block_error:
err_reqs.append(d_req_id) err_reqs.append(d_req_id)
if err_msg is None: if err_msg is None:
err_msg = "P num blocks less than D" err_msg = "P num blocks less than D"
continue continue
if num_local_blocks > num_remote_blocks:
local_block_ids = local_block_ids[-num_remote_blocks:] if not local_block_ids:
continue
# Group by indices # Group by indices
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous( group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
...@@ -1215,7 +1317,7 @@ class MooncakeConnectorWorker: ...@@ -1215,7 +1317,7 @@ class MooncakeConnectorWorker:
logger.debug( logger.debug(
"Sending kv_caches for request %s (%d blocks) to %s", "Sending kv_caches for request %s (%d blocks) to %s",
d_req_id, d_req_id,
num_remote_blocks, len(local_block_ids),
remote_session, remote_session,
) )
...@@ -1273,23 +1375,24 @@ class MooncakeConnectorWorker: ...@@ -1273,23 +1375,24 @@ class MooncakeConnectorWorker:
continue continue
seen_base_addresses.append(base_addr) seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.nbytes
if tensor_size_bytes is None: if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes tensor_size_bytes = cache.nbytes
self.num_blocks = cache.shape[0] self.num_blocks = cache.shape[0]
assert cache.shape[0] == self.num_blocks, ( assert cache.shape[0] == self.num_blocks, (
"All kv cache tensors must have the same number of blocks" "All kv cache tensors must have the same number of blocks"
) )
assert curr_tensor_size_bytes % self.num_blocks == 0, (
"Mooncake expects each kv cache tensor size to be " # Use stride-based block length so RDMA reaches the last
"divisible by the number of blocks." # block's padding (e.g. DeepseekV4 MLA alignment). stride(0)
) # reflects the actual byte distance between consecutive
self.block_len_per_layer.append( # blocks in GPU memory, which matches or exceeds the
curr_tensor_size_bytes // self.num_blocks # shape-based size.
) block_len = cache.stride(0) * cache.element_size()
self.block_len_per_layer.append(block_len)
kv_data_ptrs.append(base_addr) kv_data_ptrs.append(base_addr)
kv_data_lens.append(curr_tensor_size_bytes) kv_data_lens.append(self.num_blocks * block_len)
self.kv_caches_base_addr = seen_base_addresses self.kv_caches_base_addr = seen_base_addresses
self.seen_base_addresses = seen_base_addresses self.seen_base_addresses = seen_base_addresses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment