Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: default avatarYifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarqizixi <zixi@inferact.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <yongye@inferact.ai>
Co-authored-by: default avatarSimon Mo <simon@inferact.ai>
Co-authored-by: default avatarBugen Zhao <i@bugenzhao.com>
Co-authored-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoy Wang <yasong.wang@inferact.ai>
Co-authored-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarZhewen Li <jerven.vllm@gmail.com>
Co-authored-by: default avatarZijing Liu <liuzijing2014@gmail.com>
Co-authored-by: default avatarkhluu <khluu000@gmail.com>
Co-authored-by: default avatarqizixi <zixi@inferact.ai>
Co-authored-by: Zh...
parent 32e45636
...@@ -116,6 +116,11 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -116,6 +116,11 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
model_kwargs["attention_config"] = {"backend": attn_backend.backend.name} model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
model_kwargs["tensor_parallel_size"] = tp_size model_kwargs["tensor_parallel_size"] = tp_size
# Cap warmup memory: tests use small max_model_len (1024) but the
# engine default max_num_batched_tokens is 16384. Warming up large
# models (e.g. Llama-4-Scout-FP8) at 16384 tokens may trigger OOM.
model_kwargs.setdefault("max_num_batched_tokens", 8192)
# Sparse MLA models (DSv3.2) hit an over-strict inductor assertion in # Sparse MLA models (DSv3.2) hit an over-strict inductor assertion in
# decompose_auto_functionalized when +rotary_embedding is forced into # decompose_auto_functionalized when +rotary_embedding is forced into
# the compile graph. Disable qk_norm+rope fusion (which auto-enables # the compile graph. Disable qk_norm+rope fusion (which auto-enables
......
...@@ -9,8 +9,8 @@ from vllm.platforms import current_platform ...@@ -9,8 +9,8 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
_ceil_to_ue8m0, _ceil_to_ue8m0,
calc_diff, calc_diff,
fp8_mqa_logits, fp8_fp4_mqa_logits,
fp8_paged_mqa_logits, fp8_fp4_paged_mqa_logits,
get_num_sms, get_num_sms,
get_paged_mqa_logits_metadata, get_paged_mqa_logits_metadata,
) )
...@@ -127,8 +127,8 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool): ...@@ -127,8 +127,8 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
q_fp8 = q.to(torch.float8_e4m3fn) q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits( logits = fp8_fp4_mqa_logits(
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits (q_fp8, None), kv_fp8, weights, ks, ke, clean_logits=clean_logits
) )
ref_logits = _ref_fp8_mqa_logits( ref_logits = _ref_fp8_mqa_logits(
...@@ -150,7 +150,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool): ...@@ -150,7 +150,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
assert diff < 1e-3, f"{diff=}" assert diff < 1e-3, f"{diff=}"
def _ref_fp8_paged_mqa_logits( def _ref_fp8_fp4_paged_mqa_logits(
q: torch.Tensor, q: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
weights: torch.Tensor, weights: torch.Tensor,
...@@ -205,8 +205,10 @@ def _ref_fp8_paged_mqa_logits( ...@@ -205,8 +205,10 @@ def _ref_fp8_paged_mqa_logits(
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only" not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
) )
@pytest.mark.parametrize("clean_logits", [True, False]) def test_deepgemm_fp8_fp4_paged_mqa_logits():
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool): # NOTE: clean_logits=True is incompatible with the 2D context_lens
# required by csrc/apis/attention.hpp; only the False path is exercised.
clean_logits = False
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
...@@ -258,21 +260,29 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool): ...@@ -258,21 +260,29 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
q_fp8 = q.to(torch.float8_e4m3fn) q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
# deep_gemm paged MQA logits requires 2D context_lens of
# shape (B, next_n) (csrc/apis/attention.hpp:332-335);
# see indexer.py:607-608. For each batch/next_n token, the
# effective context length is context_lens[b] - next_n + j + 1.
next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32)
context_lens_2d = (
context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange
).contiguous()
schedule_metadata = get_paged_mqa_logits_metadata( schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms() context_lens_2d, blocksize, get_num_sms()
) )
logits = fp8_paged_mqa_logits( logits = fp8_fp4_paged_mqa_logits(
q_fp8, (q_fp8, None),
kv_cache_fp8, kv_cache_fp8,
weights, weights,
context_lens, context_lens_2d,
block_tables, block_tables,
schedule_metadata, schedule_metadata,
max_model_len, max_model_len,
clean_logits=clean_logits, clean_logits=clean_logits,
) )
ref_logits = _ref_fp8_paged_mqa_logits( ref_logits = _ref_fp8_fp4_paged_mqa_logits(
q, q,
kv_cache, kv_cache,
weights, weights,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Correctness + large-token-count launch tests for fused_q_kv_rmsnorm.
Before the grid-dim fix the kernel used grid ``(2, num_tokens)``, which hit
CUDA's 65535 grid-y cap for ``num_tokens >= 65536`` and failed with
``Triton Error [CUDA]: invalid argument`` at every large chunked-prefill
profile run. These tests pin the new grid layout.
"""
from __future__ import annotations
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.attention.ops.deepseek_v4_ops import fused_q_kv_rmsnorm
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="fused_q_kv_rmsnorm requires a CUDA/ROCm device",
)
def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
x_f32 = x.to(torch.float32)
variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
y = x_f32 * torch.rsqrt(variance + eps) * w.to(torch.float32)
return y.to(x.dtype)
@pytest.mark.parametrize("num_tokens", [1, 17, 1024, 8192])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_fused_q_kv_rmsnorm_correctness(num_tokens: int, dtype: torch.dtype):
torch.manual_seed(0)
device = "cuda"
q_size, kv_size = 192, 576
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
qw = torch.randn(q_size, dtype=dtype, device=device)
kvw = torch.randn(kv_size, dtype=dtype, device=device)
eps = 1e-6
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, eps)
qr_ref = _ref_rmsnorm(qr, qw, eps)
kv_ref = _ref_rmsnorm(kv, kvw, eps)
tol = dict(rtol=1e-2, atol=1e-2)
torch.testing.assert_close(qr_out, qr_ref, **tol)
torch.testing.assert_close(kv_out, kv_ref, **tol)
@pytest.mark.parametrize("num_tokens", [65535, 65536, 131072])
def test_fused_q_kv_rmsnorm_launches_past_grid_y_cap(num_tokens: int):
"""Regression guard: grid used to be (2, num_tokens), hitting CUDA's
65535 grid-y cap at num_tokens >= 65536. The new grid (num_tokens, 2)
lifts that bound to 2**31-1."""
device = "cuda"
dtype = torch.bfloat16
q_size, kv_size = 192, 576
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
qw = torch.randn(q_size, dtype=dtype, device=device)
kvw = torch.randn(kv_size, dtype=dtype, device=device)
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, 1e-6)
# spot-check a couple of rows against the torch reference
for row in (0, num_tokens // 2, num_tokens - 1):
torch.testing.assert_close(
qr_out[row],
_ref_rmsnorm(qr[row : row + 1], qw, 1e-6)[0],
rtol=1e-2,
atol=1e-2,
)
torch.testing.assert_close(
kv_out[row],
_ref_rmsnorm(kv[row : row + 1], kvw, 1e-6)[0],
rtol=1e-2,
atol=1e-2,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Unit-test DeepGEMM FP8 kernels (no DeepEP). Unit-test DeepGEMM FP8 and FP4 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts. Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
""" """
...@@ -21,6 +21,8 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import ( ...@@ -21,6 +21,8 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize, maybe_make_prepare_finalize,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
FusedMoEQuantDesc,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
...@@ -204,3 +206,195 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i ...@@ -204,3 +206,195 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
f"DeepGEMM path was not executed during the test. " f"DeepGEMM path was not executed during the test. "
f"Call counter: {call_counter['cnt']}" f"Call counter: {call_counter['cnt']}"
) )
# ---------------------------------------------------------------------------
# FP4 weight tests (DeepGEMM m_grouped_fp8_fp4_gemm_nt_contiguous)
# ---------------------------------------------------------------------------
def make_mxfp4_weights(
e: int,
n: int,
k: int,
):
"""
Generate (w1, w2) expert weights in MXFP4 packed format with float32 scales,
plus BF16 reference weights for validation.
w1 shape: (E, 2N, K//2) uint8 — packed FP4
w2 shape: (E, K, N//2) uint8 — packed FP4
w1_s shape: (E, 2N, K//32) float32 — per-row block-32 scales
w2_s shape: (E, K, N//32) float32 — per-row block-32 scales
w1_bf16: (E, 2N, K) — original BF16 for reference
w2_bf16: (E, K, N) — original BF16 for reference
"""
from deep_gemm.utils.math import per_token_cast_to_fp4
dtype = torch.bfloat16
gran_k = 32 # MXFP4 block size
# bf16 reference weights — scale by 1/sqrt(dim) for numerical stability
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) * (k**-0.5)
w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) * (n**-0.5)
# Quantize per-expert to FP4
w1 = torch.empty(e, 2 * n, k // 2, device="cuda", dtype=torch.uint8)
w2 = torch.empty(e, k, n // 2, device="cuda", dtype=torch.uint8)
w1_s = torch.empty(
e, 2 * n, math.ceil(k / gran_k), device="cuda", dtype=torch.float32
)
w2_s = torch.empty(e, k, math.ceil(n / gran_k), device="cuda", dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_token_cast_to_fp4(
w1_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
)
w2[i], w2_s[i] = per_token_cast_to_fp4(
w2_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
)
return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16
def _bf16_moe_reference(x, w1, w2, topk_weights, topk_ids):
"""BF16 token-loop MoE reference for correctness testing."""
import torch.nn.functional as F
num_tokens, hidden_size = x.shape
intermediate = w1.shape[1] // 2
top_k = topk_ids.shape[1]
output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=x.device)
for t in range(num_tokens):
for kk in range(top_k):
e = topk_ids[t, kk].item()
w = topk_weights[t, kk].item()
fc1 = x[t : t + 1].float() @ w1[e].float().T
linear = fc1[:, :intermediate]
gate = fc1[:, intermediate:]
act = F.silu(gate) * linear
fc2 = act @ w2[e].float().T
output[t] += w * fc2[0]
return output.to(torch.bfloat16)
def run_single_fp4_case(m, n, k, topk, num_experts):
"""
Run one (M,N,K) configuration with FP4 weights on DeepGEMM and assert
DeepGEMM FP4 == BF16 reference within tolerance.
"""
tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * (k**-0.5)
# FP4 expert weight tensors + BF16 originals for reference
w1, w2, w1_s, w2_s, w1_bf16, w2_bf16 = make_mxfp4_weights(num_experts, n, k)
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
_fp8_dtype = current_platform.fp8_dtype()
_block_shape = GroupShape(128, 128)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_s, None, None, None),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_s, None, None, None),
)
moe_config = make_dummy_moe_config()
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
DeepGemmFP4Experts,
)
deep_gemm_fp4_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=DeepGemmFP4Experts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
# DeepGEMM FP4 path
out_deepgemm_fp4 = deep_gemm_fp4_experts.apply(
hidden_states=tokens_bf16,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=num_experts,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
)
# BF16 reference using the same original weights
out_ref = _bf16_moe_reference(tokens_bf16, w1_bf16, w2_bf16, topk_weights, topk_ids)
# FP4 vs BF16 reference: quantization error from FP4 weights + FP8 activations
diff = calc_diff(out_deepgemm_fp4, out_ref)
assert diff < 0.05, f"FP4 diff exceeded 5%: {diff}"
# DeepSeek V4 dims: H=4096, I=2048, so N=2*I=4096, K=H=4096.
# FP4 quantization with block_k=32 needs large K for good accuracy.
FP4_MNKs = [
(128, 4096, 4096), # DeepSeek V4 shape
(256, 2048, 2048), # Half-size variant
]
FP4_TOPKS = [2]
FP4_NUM_EXPERTS = [8]
@pytest.mark.parametrize(("m", "n", "k"), FP4_MNKs)
@pytest.mark.parametrize("topk", FP4_TOPKS)
@pytest.mark.parametrize("num_experts", FP4_NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_fp4_vs_triton(
m, n, k, topk, num_experts, monkeypatch, workspace_init
):
pytest.importorskip("deep_gemm.utils.math")
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_DeepGemmFP4Experts = importlib.import_module(
"vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe"
).DeepGemmFP4Experts
call_counter = {"cnt": 0}
orig_fn = _DeepGemmFP4Experts.apply
def _spy_apply(*args, **kwargs):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_DeepGemmFP4Experts, "apply", _spy_apply)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
run_single_fp4_case(
m=m,
n=n,
k=k,
topk=topk,
num_experts=num_experts,
)
# ensure that the DeepGEMM FP4 path was indeed taken.
assert call_counter["cnt"] == 1, (
f"DeepGEMM FP4 path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.platforms import current_platform
def _torch_topk_softplus_sqrt(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
e_score_correction_bias: torch.Tensor | None = None,
input_ids: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
):
scores = F.softplus(gating_output.float()).sqrt()
original_scores = scores
if e_score_correction_bias is not None:
scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
else:
scores_for_choice = scores
if hash_indices_table is not None:
assert input_ids is not None
topk_ids = hash_indices_table[input_ids.long()]
else:
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=True)[1]
topk_weights = original_scores.gather(1, topk_ids.long())
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method():
assert (
get_routing_method_type(
scoring_func="sqrtsoftplus",
top_k=8,
renormalize=True,
num_expert_group=None,
has_e_score_bias=True,
)
== RoutingMethodType.DeepseekV4
)
assert (
get_routing_method_type(
scoring_func="sqrtsoftplus",
top_k=8,
renormalize=False,
num_expert_group=None,
has_e_score_bias=True,
)
== RoutingMethodType.Unspecified
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [128, 256, 384, 512])
@pytest.mark.parametrize("topk", [6, 8, 16])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_softplus_sqrt(
num_tokens: int,
hidden_size: int,
num_experts: int,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
dtype: torch.dtype,
):
torch.manual_seed(0)
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
e_score_correction_bias = torch.randn(
(num_experts,), dtype=torch.float32, device="cuda"
)
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=gating_output,
scoring_func="sqrtsoftplus",
e_score_correction_bias=e_score_correction_bias,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
)
# Different kernels may return the topk experts in different orders when
# scores tie; sort by expert id before comparing.
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
sorted_w = topk_weights.gather(1, idx_ops)
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [256, 384, 512])
@pytest.mark.parametrize("topk", [6, 8, 16])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_softplus_sqrt_hash(
num_tokens: int,
hidden_size: int,
num_experts: int,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
dtype: torch.dtype,
):
torch.manual_seed(0)
vocab_size = 1024
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
# Per-token fixed expert selection: for each vocab id pick `topk` distinct
# experts.
hash_indices_table = torch.stack(
[torch.randperm(num_experts)[:topk] for _ in range(vocab_size)]
).to(device="cuda", dtype=torch.int32)
input_ids = torch.randint(
0, vocab_size, (num_tokens,), dtype=torch.int32, device="cuda"
)
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
input_ids=input_ids,
hash_indices_table=hash_indices_table,
)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=gating_output,
scoring_func="sqrtsoftplus",
e_score_correction_bias=None,
topk=topk,
renormalize=renormalize,
input_tokens=input_ids,
hash_indices_table=hash_indices_table,
routed_scaling_factor=routed_scaling_factor,
)
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
sorted_w = topk_weights.gather(1, idx_ops)
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.
Two paths tested:
A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
B) Indexer: head_dim=128 (all FP8), quant_block=128
These serve as golden references for validating the future fused
compressor+quant+cache kernel.
"""
import math
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.v1.attention.ops.deepseek_v4_ops import (
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
"""PyTorch reference for UE8M0 FP8 quantization (per-block, power-of-2 scale).
Returns (x_fp8, scales) where x_fp8 is float8_e4m3fn and scales are float32.
"""
assert x.dim() == 1
n = x.numel()
n_blocks = math.ceil(n / block_size)
x_fp8 = torch.zeros(n, dtype=torch.float8_e4m3fn, device=x.device)
scales = torch.zeros(n_blocks, dtype=torch.float32, device=x.device)
for i in range(n_blocks):
start = i * block_size
end = min(start + block_size, n)
block = x[start:end].float()
amax = block.abs().max().clamp(min=1e-4)
raw_scale = amax / fp8_max
exponent = math.ceil(math.log2(raw_scale.item()))
scale = 2.0**exponent
scales[i] = scale
quantized = (block / scale).clamp(-fp8_max, fp8_max)
x_fp8[start:end] = quantized.to(torch.float8_e4m3fn)
return x_fp8, scales
# ── Test A: DeepseekV4 Attention path ──────────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size: int):
"""compressed_kv → quantize_and_insert_k_cache → dequantize_and_gather_k_cache
→ compare against original."""
HEAD_DIM = 512
NOPE_DIM = 448
HEAD_BYTES = 584 # 448 fp8 + 128 bf16 + 8 uint8 scale
FP8_MAX = 448.0
QUANT_BLOCK = 64
num_blocks = (num_tokens + block_size - 1) // block_size + 1
device = "cuda"
# Random compressed_kv (simulates compressor output)
compressed_kv = torch.randn(
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
# ── Quant + insert ──────────────────────────────────────────────────
k_cache = torch.zeros(
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
)
k_cache_2d = k_cache.view(num_blocks, -1)
# Sequential slot mapping: token i → slot i
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
quantize_and_insert_k_cache(
compressed_kv, k_cache_2d, slot_mapping, block_size=block_size
)
# ── Gather + dequant ────────────────────────────────────────────────
num_reqs = 1
max_blocks_per_seq = num_blocks
out = torch.zeros(
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
# block_table: request 0 uses physical blocks 0, 1, ...
block_table = torch.arange(
max_blocks_per_seq, dtype=torch.int32, device=device
).unsqueeze(0)
dequantize_and_gather_k_cache(
out, k_cache, seq_lens, None, block_table, block_size, offset=0
)
recovered = out[0, :num_tokens]
# ── NoPE portion (first 448): FP8 quantized, expect UE8M0 error ──
nope_orig = compressed_kv[:, :NOPE_DIM].float()
nope_recv = recovered[:, :NOPE_DIM].float()
nope_diff = (nope_recv - nope_orig).abs()
# Per-token check: FP8 e4m3 (3-bit mantissa) worst-case error is
# half-ULP at the largest representable value. At y ≈ 448 (max),
# ULP = 2^(8-3) = 32, so error ≤ 16 * scale.
for t in range(num_tokens):
_, scales = _ue8m0_reference(
compressed_kv[t, :NOPE_DIM].float(), QUANT_BLOCK, FP8_MAX
)
max_allowed = 16.0 * scales.max().item()
token_diff = nope_diff[t].max().item()
assert token_diff <= max_allowed, (
f"Token {t} nope diff {token_diff} exceeds max_allowed "
f"{max_allowed} (scale={scales.max().item()})"
)
# ── RoPE portion (last 64): stored as bf16, should be exact ─────
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs()
assert rope_diff.max().item() == 0.0, (
f"RoPE portion should be exact but got max diff {rope_diff.max().item()}"
)
# ── Test B: Indexer path ────────────────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_indexer_quant_cache_roundtrip(num_tokens: int, block_size: int):
"""k → indexer_k_quant_and_cache → cp_gather_indexer_k_quant_cache
→ manual dequant → compare against original."""
HEAD_DIM = 128
QUANT_BLOCK_SIZE = 128
# cache_stride = head_dim + (head_dim * 4 / quant_block_size) = 128 + 4 = 132
CACHE_STRIDE = HEAD_DIM + HEAD_DIM * 4 // QUANT_BLOCK_SIZE
num_blocks = (num_tokens + block_size - 1) // block_size + 1
device = "cuda"
# Random K (simulates compressor output for indexer)
k = torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
# ── Quant + insert ──────────────────────────────────────────────────
kv_cache = torch.zeros(
num_blocks, block_size, CACHE_STRIDE, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, QUANT_BLOCK_SIZE, "ue8m0")
# ── Gather ──────────────────────────────────────────────────────────
max_blocks_per_seq = num_blocks
block_table = torch.arange(
max_blocks_per_seq, dtype=torch.int32, device=device
).unsqueeze(0)
cu_seq_lens = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
# dst_k: [total_seq_len, head_dim] as uint8 (raw FP8 bytes)
dst_k = torch.zeros(num_tokens, HEAD_DIM, dtype=torch.uint8, device=device)
# dst_scale: [total_seq_len, head_dim/quant_block*4] as uint8 (raw float32 bytes)
num_scale_bytes = HEAD_DIM * 4 // QUANT_BLOCK_SIZE # 4
dst_scale = torch.zeros(
num_tokens, num_scale_bytes, dtype=torch.uint8, device=device
)
ops.cp_gather_indexer_k_quant_cache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
)
# ── Manual dequant ──────────────────────────────────────────────────
k_fp8 = dst_k.view(torch.float8_e4m3fn).float() # [num_tokens, 128]
scale = dst_scale.view(torch.float32) # [num_tokens, 1]
k_recovered = k_fp8 * scale # [num_tokens, 128]
# ── Compare ─────────────────────────────────────────────────────────
diff = (k_recovered - k.float()).abs()
k_abs = k.float().abs()
for t in range(num_tokens):
amax = k_abs[t].max().clamp(min=1e-4).item()
# UE8M0: scale = 2^ceil(log2(amax / 448))
exponent = math.ceil(math.log2(amax / 448.0))
ue8m0_scale = 2.0**exponent
# FP8 e4m3 (3-bit mantissa): worst-case error = 16 * scale
max_allowed = 16.0 * ue8m0_scale
token_diff = diff[t].max().item()
assert token_diff <= max_allowed, (
f"Token {t} diff {token_diff} exceeds max_allowed "
f"{max_allowed} (scale={ue8m0_scale})"
)
def test_indexer_gather_accepts_upper_bound_output():
"""Gather only exact cu_seq_lens even when dst is over-allocated."""
head_dim = 128
quant_block_size = 128
cache_stride = head_dim + head_dim * 4 // quant_block_size
valid_tokens = 9
upper_bound_tokens = 13
block_size = 16
num_blocks = 2
sentinel = 123
device = "cuda"
k = torch.randn(valid_tokens, head_dim, dtype=torch.bfloat16, device=device)
kv_cache = torch.zeros(
num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(valid_tokens, dtype=torch.int64, device=device)
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0")
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
0
)
cu_seq_lens = torch.tensor([0, valid_tokens], dtype=torch.int32, device=device)
dst_k = torch.full(
(upper_bound_tokens, head_dim), sentinel, dtype=torch.uint8, device=device
)
num_scale_bytes = head_dim * 4 // quant_block_size
dst_scale = torch.full(
(upper_bound_tokens, num_scale_bytes),
sentinel,
dtype=torch.uint8,
device=device,
)
ops.cp_gather_indexer_k_quant_cache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
)
torch.accelerator.synchronize()
k_recovered = dst_k[:valid_tokens].view(torch.float8_e4m3fn).float() * dst_scale[
:valid_tokens
].view(torch.float32)
diff = (k_recovered - k.float()).abs()
max_allowed = (16.0 * dst_scale[:valid_tokens].view(torch.float32).max()).item()
assert diff.max().item() <= max_allowed
assert torch.all(dst_k[valid_tokens:] == sentinel)
assert torch.all(dst_scale[valid_tokens:] == sentinel)
# ── Test C: DeepseekV4 attention with values at different magnitudes ───────────
def test_deepseek_v4_quant_magnitude_range():
"""Test that quantization handles a range of magnitudes correctly."""
HEAD_DIM = 512
NOPE_DIM = 448
HEAD_BYTES = 584
block_size = 16
num_tokens = 4
num_blocks = 2
device = "cuda"
# Create inputs with varying magnitudes: small, medium, large
compressed_kv = torch.zeros(
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
compressed_kv[0] = 0.001 # very small
compressed_kv[1] = 1.0 # unit scale
compressed_kv[2] = 100.0 # large
compressed_kv[3] = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
k_cache = torch.zeros(
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
quantize_and_insert_k_cache(
compressed_kv, k_cache.view(num_blocks, -1), slot_mapping, block_size
)
out = torch.zeros(1, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
0
)
dequantize_and_gather_k_cache(
out, k_cache, seq_lens, None, block_table, block_size, offset=0
)
recovered = out[0, :num_tokens]
# RoPE portion must be exact
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs().max()
assert rope_diff.item() == 0.0, f"RoPE diff {rope_diff.item()}"
# NoPE: relative error should be reasonable
for t in range(num_tokens):
orig = compressed_kv[t, :NOPE_DIM].float()
recv = recovered[t, :NOPE_DIM].float()
abs_diff = (recv - orig).abs().max().item()
magnitude = orig.abs().max().item()
if magnitude > 0.01:
rel_err = abs_diff / magnitude
assert rel_err < 0.15, (
f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
f"magnitude={magnitude:.4f}"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standalone unit test for the horizontally-fused DeepseekV4-MLA kernel:
fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
- Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last 64 dims
- KV side: GPT-J RoPE on last 64 + UE8M0 FP8 quant + paged cache insert
We compare against:
- PyTorch reference for RMSNorm + GPT-J RoPE on Q
- Existing Triton `quantize_and_insert_k_cache` + round-trip via
`dequantize_and_gather_k_cache` for KV
The kernel is imported via
`torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`.
"""
import pytest
import torch
from vllm.v1.attention.ops.deepseek_v4_ops import (
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
# ── Constants matching the kernel ────────────────────────────────────────────
HEAD_DIM = 512
ROPE_DIM = 64
NOPE_DIM = HEAD_DIM - ROPE_DIM # 448
QUANT_BLOCK = 64
FP8_MAX = 448.0
HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584
# ── PyTorch reference implementations ────────────────────────────────────────
def make_cos_sin_cache(max_pos: int, rope_dim: int, dtype, device):
"""Build a cos||sin cache matching DeepseekV4ScalingRotaryEmbedding layout.
cos_sin_cache[pos, :rope_dim/2] = cos(theta), [rope_dim/2:] = sin(theta).
"""
base = 10000.0
inv_freq = 1.0 / (
base
** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim)
)
t = torch.arange(max_pos, dtype=torch.float32, device=device)
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rope_dim/2]
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1) # [max_pos, rope_dim]
return cache.to(dtype)
def apply_rope_gptj_last_k(
x: torch.Tensor, positions: torch.Tensor, cos_sin_cache: torch.Tensor
) -> torch.Tensor:
"""GPT-J-style (interleaved-pair) RoPE on the LAST rope_dim elements.
x: [..., head_dim] float32
positions: [num_tokens] int64 (positions[i] corresponds to x[i, ...])
cos_sin_cache: [max_pos, rope_dim] float (cos|sin layout)
Returns rotated x (same shape/dtype).
"""
rope_dim = cos_sin_cache.shape[-1]
half = rope_dim // 2
head_dim = x.shape[-1]
nope_dim = head_dim - rope_dim
# Gather cos/sin for each token position: [num_tokens, rope_dim]
cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim]
cos = cs[..., :half] # [N, half]
sin = cs[..., half:] # [N, half]
# Reshape leading dims so we can broadcast: x shape [..., head_dim].
# Bring token dim to front; assume x is [num_tokens, ..., head_dim].
# We rely on positions being per-token and all other dims sharing the same pos.
rope = x[..., nope_dim:].float() # [..., rope_dim]
# Make rope pairs: reshape last dim to [half, 2]
shape = rope.shape
rope = rope.reshape(*shape[:-1], half, 2)
even = rope[..., 0] # [..., half]
odd = rope[..., 1]
# Broadcast cos/sin over any heads dim in between. cos/sin are [N, half].
# Add singleton dims for intermediate axes.
for _ in range(rope.ndim - 3):
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
new_even = even * cos - odd * sin
new_odd = even * sin + odd * cos
rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape)
out = x.clone().float()
out[..., nope_dim:] = rope_rotated
return out.to(x.dtype)
def rmsnorm_no_weight(x: torch.Tensor, eps: float) -> torch.Tensor:
"""RMSNorm with no learnable weight, matching
`RMSNorm(head_dim, has_weight=False)`."""
orig_dtype = x.dtype
xf = x.float()
variance = xf.pow(2).mean(dim=-1, keepdim=True)
return (xf * torch.rsqrt(variance + eps)).to(orig_dtype)
# ── Dispatch to the CUDA op (skip test cleanly if it isn't built in) ─────────
def _op_available() -> bool:
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or not _op_available(),
reason="CUDA not available or fused DeepseekV4 op not built in",
)
def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs):
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
)
# ── Test 1: Q path numerical parity ──────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
@pytest.mark.parametrize("n_heads", [8, 64])
def test_q_path_matches_reference(num_tokens: int, n_heads: int):
torch.manual_seed(0)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
# Reference: RMSNorm (no weight) per head, then GPT-J RoPE on last 64.
q_ref = rmsnorm_no_weight(q, eps)
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
# Fused call with dummy KV tensors (KV branch will write slot_mapping=-1 → noop).
num_blocks = 2
bs = 16
kv = torch.zeros(num_tokens, HEAD_DIM, dtype=dtype, device=device)
k_cache = torch.zeros(
num_blocks, bs, HEAD_BYTES, dtype=torch.uint8, device=device
).view(num_blocks, -1)
slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64, device=device)
q_fused = q.clone()
_call_fused(q_fused, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs)
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
# ── Test 2: KV path round-trip byte/value parity ─────────────────────────────
def _ue8m0_per_block_scales(kv_roped_nope_f32: torch.Tensor, qblock: int):
"""Return per-token per-block max scale (used to bound FP8 error)."""
n_tok, nope = kv_roped_nope_f32.shape
n_blocks = nope // qblock
blocks = kv_roped_nope_f32.view(n_tok, n_blocks, qblock)
absmax = blocks.abs().amax(dim=-1).clamp(min=1e-4)
raw = absmax / FP8_MAX
exponent = torch.ceil(torch.log2(raw))
return torch.pow(2.0, exponent) # [n_tok, n_blocks]
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
@pytest.mark.parametrize("block_size", [16, 64])
def test_kv_path_matches_reference(num_tokens: int, block_size: int):
torch.manual_seed(1)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# ── Reference path: RoPE on kv, then existing Triton quant+insert ──────
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# ── Fused path (dummy q, single head) ──────────────────────────────────
k_cache_fused = torch.zeros_like(k_cache_ref)
q_dummy = torch.zeros(num_tokens, 1, HEAD_DIM, dtype=dtype, device=device)
_call_fused(
q_dummy,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
# ── Round-trip compare via dequant+gather ──────────────────────────────
def _dequant(k_cache_2d):
num_reqs = 1
max_blocks = num_blocks
out = torch.zeros(
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
block_table = torch.arange(
max_blocks, dtype=torch.int32, device=device
).unsqueeze(0)
# gather_lens arg is None (use seq_lens)
k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES)
dequantize_and_gather_k_cache(
out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0
)
return out[0, :num_tokens]
recovered_ref = _dequant(k_cache_ref)
recovered_fused = _dequant(k_cache_fused)
# NoPE: per-block UE8M0 FP8 error bound (half-ULP at max = 16 * scale).
scales = _ue8m0_per_block_scales(kv_ref[:, :NOPE_DIM].float(), QUANT_BLOCK)
for t in range(num_tokens):
max_allowed = 16.0 * scales[t].max().item()
diff_ref = (
(recovered_ref[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
)
diff_fused = (
(recovered_fused[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
)
assert diff_ref <= max_allowed, (
f"ref NoPE token {t} diff {diff_ref} > {max_allowed}"
)
assert diff_fused <= max_allowed, (
f"fused NoPE token {t} diff {diff_fused} > {max_allowed}"
)
# RoPE region: bf16 stored exactly → zero diff.
rope_diff = (recovered_fused[:, NOPE_DIM:] - kv_ref[:, NOPE_DIM:]).abs().max()
assert rope_diff.item() == 0.0, f"RoPE portion not exact: {rope_diff.item()}"
# Exact byte equality of the two cache buffers — strong parity.
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Test 2b: DP padding (slot_mapping shorter than q/kv) ─────────────────────
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("pad", [1, 5])
@pytest.mark.parametrize("block_size", [16, 64])
def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):
"""slot_mapping.size(0) < q.size(0): the kernel must skip padded
tokens in the KV branch while still running Q-norm+RoPE on all rows."""
torch.manual_seed(3)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
total = num_tokens + pad
kv = torch.randn(total, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(total, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# Reference: only the first num_tokens kv rows get inserted.
kv_ref = apply_rope_gptj_last_k(
kv[:num_tokens], positions[:num_tokens], cos_sin_cache
)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# Fused: pass full-sized q/kv/positions, shorter slot_mapping.
q_dummy = torch.zeros(total, 1, HEAD_DIM, dtype=dtype, device=device)
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
q_dummy,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Test 3: combined single-call Q + KV parity ───────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 17])
@pytest.mark.parametrize("n_heads", [8, 64])
@pytest.mark.parametrize("block_size", [16, 64])
def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
torch.manual_seed(2)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# Reference.
q_ref = rmsnorm_no_weight(q, eps)
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# Fused single call.
q_fused = q.clone()
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
q_fused,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit test for fused_indexer_q_rope_quant.
Compares the fused Triton kernel against the unfused reference flow used by
the DeepseekV4 indexer in model_tracking:
q_rot = ops.rotary_embedding(positions, q, None, head_dim, cos_sin_cache,
is_neox_style=False,
rope_dim_offset=head_dim - rope_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q_rot, head_dim, use_ue8m0=True)
weights_out = weights * q_scale * softmax_scale * head_scale
Expects bit-exact equality on both q_fp8 and weights_out.
"""
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
fused_indexer_q_rope_quant,
)
HEAD_DIM = 128
ROPE_DIM = 64
N_HEAD = 64
MAX_POS = 4096
def _reference(
positions: torch.Tensor,
q: torch.Tensor,
cos_sin_cache: torch.Tensor,
weights: torch.Tensor,
softmax_scale: float,
head_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
q_rot = q.clone()
ops.rotary_embedding(
positions,
q_rot,
None,
HEAD_DIM,
cos_sin_cache,
False, # is_neox_style=False → GPT-J interleaved
HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail
False,
)
q_fp8, q_scale = per_token_group_quant_fp8(
q_rot.view(-1, HEAD_DIM).contiguous(),
HEAD_DIM,
use_ue8m0=True,
)
q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM)
q_scale = q_scale.view(-1, N_HEAD)
weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale
return q_fp8, weights_out
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257])
@pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16])
@torch.inference_mode()
def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype):
device = "cuda"
torch.manual_seed(0)
q = torch.randn(num_tokens, N_HEAD, HEAD_DIM, dtype=torch.bfloat16, device=device)
positions = torch.randint(
0, MAX_POS, (num_tokens,), dtype=torch.int64, device=device
)
cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype, device=device)
weights = torch.randn(num_tokens, N_HEAD, dtype=torch.bfloat16, device=device)
softmax_scale = HEAD_DIM**-0.5
head_scale = N_HEAD**-0.5
q_fp8_ref, weights_ref = _reference(
positions, q, cos_sin_cache, weights, softmax_scale, head_scale
)
q_fp8_fused, weights_fused = fused_indexer_q_rope_quant(
positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale
)
# fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8.
ref_bits = q_fp8_ref.view(torch.int8)
fused_bits = q_fp8_fused.view(torch.int8)
assert torch.equal(ref_bits, fused_bits), (
f"q_fp8 mismatch: "
f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ"
)
assert torch.equal(weights_ref, weights_fused), (
f"weights mismatch: max abs diff "
f"{(weights_ref - weights_fused).abs().max().item()}"
)
This diff is collapsed.
...@@ -718,7 +718,6 @@ def test_persistent_topk_stress() -> None: ...@@ -718,7 +718,6 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [2000, 6000, 30000, 80000], "seq_lens": [2000, 6000, 30000, 80000],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="mixed_all_paths", id="mixed_all_paths",
...@@ -727,7 +726,6 @@ def test_persistent_topk_stress() -> None: ...@@ -727,7 +726,6 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [2048, 4096, 8192, 16000], "seq_lens": [2048, 4096, 8192, 16000],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="all_decode_medium", id="all_decode_medium",
...@@ -736,7 +734,6 @@ def test_persistent_topk_stress() -> None: ...@@ -736,7 +734,6 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [70000, 100000, 163840], "seq_lens": [70000, 100000, 163840],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="all_large", id="all_large",
...@@ -745,7 +742,6 @@ def test_persistent_topk_stress() -> None: ...@@ -745,7 +742,6 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [32767, 32768, 32769, 32772], "seq_lens": [32767, 32768, 32769, 32772],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="large_threshold_boundary", id="large_threshold_boundary",
...@@ -754,7 +750,6 @@ def test_persistent_topk_stress() -> None: ...@@ -754,7 +750,6 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [5000], "seq_lens": [5000],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="single_row_medium", id="single_row_medium",
...@@ -772,15 +767,15 @@ def test_persistent_topk_stress() -> None: ...@@ -772,15 +767,15 @@ def test_persistent_topk_stress() -> None:
pytest.param( pytest.param(
{ {
"seq_lens": [100, 2048, 10000, 80000], "seq_lens": [100, 2048, 10000, 80000],
"top_k": 2048,
"data_type": "random", "data_type": "random",
}, },
id="trivial_medium_large_mix", id="trivial_medium_large_mix",
), ),
], ],
) )
@pytest.mark.parametrize("top_k", [512, 2048])
@torch.inference_mode() @torch.inference_mode()
def test_persistent_topk(test_config: dict) -> None: def test_persistent_topk(test_config: dict, top_k: int) -> None:
""" """
Tests specific to the persistent_topk kernel: Tests specific to the persistent_topk kernel:
- Mixed medium/large rows in the same batch (dynamic per-row dispatch) - Mixed medium/large rows in the same batch (dynamic per-row dispatch)
...@@ -790,14 +785,15 @@ def test_persistent_topk(test_config: dict) -> None: ...@@ -790,14 +785,15 @@ def test_persistent_topk(test_config: dict) -> None:
run_large_context_topk_test( run_large_context_topk_test(
batch_size=len(test_config["seq_lens"]), batch_size=len(test_config["seq_lens"]),
seq_lens=test_config["seq_lens"], seq_lens=test_config["seq_lens"],
top_k=test_config["top_k"], top_k=top_k,
data_type=test_config.get("data_type", "random"), data_type=test_config.get("data_type", "random"),
) )
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("top_k", [512, 2048])
@torch.inference_mode() @torch.inference_mode()
def test_persistent_topk_padded_stride() -> None: def test_persistent_topk_padded_stride(top_k: int) -> None:
""" """
Test persistent_topk with padded logits (large stride, small seq_len) Test persistent_topk with padded logits (large stride, small seq_len)
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
...@@ -806,7 +802,6 @@ def test_persistent_topk_padded_stride() -> None: ...@@ -806,7 +802,6 @@ def test_persistent_topk_padded_stride() -> None:
set_random_seed(42) set_random_seed(42)
torch.set_default_device("cuda:0") torch.set_default_device("cuda:0")
top_k = 2048
batch_size = 4 batch_size = 4
padded_stride = 163840 # DeepSeek-V3.2 max_model_len padded_stride = 163840 # DeepSeek-V3.2 max_model_len
actual_seq_lens = [3000, 5000, 8000, 12000] actual_seq_lens = [3000, 5000, 8000, 12000]
......
...@@ -41,7 +41,9 @@ class DummyRouter(BaseRouter): ...@@ -41,7 +41,9 @@ class DummyRouter(BaseRouter):
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.FUSED_TOPK return RoutingMethodType.FUSED_TOPK
def _compute_routing(self, hidden_states, router_logits, indices_type): def _compute_routing(
self, hidden_states, router_logits, indices_type, *, input_ids=None
):
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64) topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
return topk_weights, topk_ids return topk_weights, topk_ids
......
...@@ -260,6 +260,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -260,6 +260,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
), ),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"DeepseekV4ForCausalLM": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V4-Flash", is_available_online=False
),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"), "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"), "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
"ExaoneForCausalLM": _HfExamplesInfo( "ExaoneForCausalLM": _HfExamplesInfo(
...@@ -1482,6 +1485,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1482,6 +1485,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="luccafong/deepseek_mtp_draft_random", speculative_model="luccafong/deepseek_mtp_draft_random",
trust_remote_code=True, trust_remote_code=True,
), ),
"DeepSeekV4MTPModel": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V4-Flash",
speculative_model="deepseek-ai/DeepSeek-V4-Flash",
trust_remote_code=True,
is_available_online=False,
),
"ErnieMTPModel": _HfExamplesInfo( "ErnieMTPModel": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT", "baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True, trust_remote_code=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import pytest
import torch
from vllm.model_executor.models.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
make_deepseek_v4_expert_params_mapping,
)
from vllm.platforms import current_platform
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda(),
reason="DeepSeek V4 MegaMoE requires CUDA",
)
def test_deepseek_v4_mega_moe_expert_mapping():
mapping = make_deepseek_v4_expert_params_mapping(2)
assert mapping == [
("experts.w13_", "experts.0.w1.", 0, "w1"),
("experts.w2_", "experts.0.w2.", 0, "w2"),
("experts.w13_", "experts.0.w3.", 0, "w3"),
("experts.w13_", "experts.1.w1.", 1, "w1"),
("experts.w2_", "experts.1.w2.", 1, "w2"),
("experts.w13_", "experts.1.w3.", 1, "w3"),
]
def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float():
raw = torch.tensor([0, 126, 127, 128], dtype=torch.uint8)
decoded = DeepseekV4MegaMoEExperts._ue8m0_uint8_to_float(raw)
assert torch.equal(decoded.view(torch.int32), raw.to(torch.int32) << 23)
assert decoded[0].item() == 0.0
assert decoded[1].item() == 0.5
assert decoded[2].item() == 1.0
assert decoded[3].item() == 2.0
def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
vllm_config = SimpleNamespace(
scheduler_config=SimpleNamespace(max_num_batched_tokens=4)
)
experts = DeepseekV4MegaMoEExperts(
vllm_config,
num_experts=4,
num_local_experts=2,
experts_start_idx=2,
top_k=2,
hidden_size=128,
intermediate_size=128,
)
nonlocal_weight = torch.ones(128, 64, dtype=torch.uint8)
assert (
experts.weight_loader(
experts.w13_weight,
nonlocal_weight,
"experts.w13_weight",
shard_id="w1",
expert_id=1,
return_success=True,
)
is False
)
w1 = torch.full((128, 64), 3, dtype=torch.uint8)
w3 = torch.full((128, 64), 7, dtype=torch.uint8)
w2 = torch.full((128, 64), 11, dtype=torch.uint8)
assert experts.weight_loader(
experts.w13_weight,
w1,
"experts.w13_weight",
shard_id="w1",
expert_id=2,
return_success=True,
)
assert experts.weight_loader(
experts.w13_weight,
w3,
"experts.w13_weight",
shard_id="w3",
expert_id=2,
return_success=True,
)
assert experts.weight_loader(
experts.w2_weight,
w2,
"experts.w2_weight",
shard_id="w2",
expert_id=2,
return_success=True,
)
assert torch.equal(experts.w13_weight[0, :128], w1)
assert torch.equal(experts.w13_weight[0, 128:], w3)
assert torch.equal(experts.w2_weight[0], w2)
assert torch.count_nonzero(experts.w13_weight[1]) == 0
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.",
)
def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8
device = torch.device("cuda")
num_tokens = 7
hidden_size = 256
top_k = 8
generator = torch.Generator(device=device)
generator.manual_seed(0)
hidden_states = (
torch.randn(
num_tokens,
hidden_size,
device=device,
dtype=torch.float32,
generator=generator,
)
* 17.0
).to(torch.bfloat16)
hidden_states[0, :32] = 0
hidden_states[1, 32:64] = 1.0e-6
hidden_states[2, 64:96] = -1.0e-6
topk_ids = torch.randint(
0,
256,
(num_tokens, top_k),
device=device,
dtype=torch.int32,
generator=generator,
)
topk_weights = torch.randn(
num_tokens,
top_k,
device=device,
dtype=torch.float32,
generator=generator,
)
ref_x, ref_x_sf = per_token_cast_to_fp8(
hidden_states,
use_ue8m0=True,
gran_k=32,
use_packed_ue8m0=True,
)
ref_topk_idx = topk_ids.to(torch.int64)
ref_topk_weights = topk_weights.clone()
fused_x = torch.empty_like(ref_x)
fused_x_sf = torch.empty_like(ref_x_sf)
fused_topk_idx = torch.empty_like(ref_topk_idx)
fused_topk_weights = torch.empty_like(ref_topk_weights)
_stage_deepseek_v4_mega_moe_inputs(
hidden_states,
topk_weights,
topk_ids,
fused_x,
fused_x_sf,
fused_topk_idx,
fused_topk_weights,
)
torch.accelerator.synchronize()
assert torch.equal(fused_x.view(torch.uint8), ref_x.view(torch.uint8))
assert torch.equal(fused_x_sf, ref_x_sf)
assert torch.equal(fused_topk_idx, ref_topk_idx)
assert torch.equal(
fused_topk_weights.view(torch.uint8),
ref_topk_weights.view(torch.uint8),
)
...@@ -6,6 +6,7 @@ from transformers import AutoTokenizer ...@@ -6,6 +6,7 @@ from transformers import AutoTokenizer
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning import ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
...@@ -33,6 +34,12 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): ...@@ -33,6 +34,12 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type):
assert isinstance(parser._parser, expected_parser_type) assert isinstance(parser._parser, expected_parser_type)
def test_deepseek_v4_reasoning_parser_alias():
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4")
assert parser_cls is DeepSeekV3ReasoningParser
def test_identity_reasoning_parser_basic(tokenizer): def test_identity_reasoning_parser_basic(tokenizer):
parser = IdentityReasoningParser(tokenizer) parser = IdentityReasoningParser(tokenizer)
......
{
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["location"]
}
}
},
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web for information",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query"
},
"num_results": {
"type": "integer",
"description": "Number of results to return"
}
},
"required": ["query"]
}
}
}
],
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather in Beijing?"
},
{
"role": "assistant",
"reasoning": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
"tool_calls": [
{
"id": "call_001",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_001",
"content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
},
{
"role": "assistant",
"reasoning": "Got the weather data. Let me format a nice response.",
"content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
}
]
}
[
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"reasoning": "The user said hello, I should greet back.",
"content": "Hi there! How can I help you?"
},
{
"role": "user",
"content": "What is the capital of France?"
},
{
"role": "assistant",
"reasoning": "The user asks about the capital of France. It is Paris.",
"content": "The capital of France is Paris."
}
]
\ No newline at end of file
[
{
"role": "system",
"content": "该助手为DeepSeek,由深度求索公司创造。"
},
{
"role": "latest_reminder",
"content": "2026-02-21,星期六,广州,App,中文"
},
{
"role": "developer",
"content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Web search. Split multiple queries with '||'.",
"parameters": {
"type": "object",
"properties": {
"queries": {
"type": "string",
"description": "query1||query2"
}
},
"required": [
"queries"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
},
{
"type": "function",
"function": {
"name": "open",
"description": "Batch open IDs (format 【{id}†...】) or URLs.",
"parameters": {
"type": "object",
"properties": {
"open_list": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {
"description": "ID or URL",
"anyOf": [
{
"type": "integer"
},
{
"type": "string"
}
],
"default": -1
},
"cursor": {
"type": "integer",
"description": "",
"default": -1
},
"loc": {
"type": "integer",
"description": "Start line",
"default": -1
},
"num_lines": {
"type": "integer",
"description": "",
"default": -1
},
"view_source": {
"type": "boolean",
"description": "",
"default": false
}
},
"additionalProperties": false
},
"description": ""
}
},
"required": [
"open_list"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
},
{
"type": "function",
"function": {
"name": "find",
"description": "Find exact text pattern in pages.",
"parameters": {
"type": "object",
"properties": {
"find_list": {
"type": "array",
"items": {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": ""
},
"cursor": {
"type": "integer",
"description": "",
"default": -1
}
},
"required": [
"pattern"
],
"additionalProperties": false
},
"description": ""
}
},
"required": [
"find_list"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
}
]
},
{
"role": "assistant",
"content": "",
"reasoning": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
"tool_calls": [
{
"type": "function",
"function": {
"name": "search",
"arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
}
}
]
},
{
"role": "tool",
"content": "[0]"
},
{
"role": "assistant",
"content": "请及时就医。",
"reasoning": "现在开始组织回答。",
"tool_calls": []
}
]
\ No newline at end of file
[
{
"role": "system",
"content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
},
{
"role": "latest_reminder",
"content": "2024-11-15,上海市,App,中文"
},
{
"role": "user",
"content": "热海大滚锅是世界著名温泉吗"
},
{
"role": "assistant",
"content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
"mask": 1
},
{
"role": "user",
"content": "世界著名温泉有哪些",
"task": "action"
},
{
"role": "assistant",
"content": "Search"
}
]
\ No newline at end of file
<|begin▁of▁sentence|>
## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
<|DSML|tool_calls>
<|DSML|invoke name="$TOOL_NAME">
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
...
</|DSML|invoke>
<|DSML|invoke name="$TOOL_NAME2">
...
</|DSML|invoke>
</|DSML|tool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
Otherwise, output directly after </think> with tool calls or final response.
### Available Tool Schemas
{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
You are a helpful assistant.<|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
<|DSML|tool_calls>
<|DSML|invoke name="get_weather">
<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
<|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
</|DSML|invoke>
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
\ No newline at end of file
<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment